diff --git a/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c25d929f202b8609fdadccf09fa33f2815fb7df Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1270e9134dfdf68b1658a320f107df6ef3a2ba92 Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f5ac2f8487445f7c584207de635871e35e59f3 Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/datalib_singfake.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b49c972a99e034e49974bd9f3dfe2e32b703c78e Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/main.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98e1f8197ddd0a2f39c8d17f12beb6d92e0f54dd Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ca6d42f07755a2c3028bc0dcace83985cfceb20 Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc b/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f946ae8f730fcc1d6ee2ad0a2a7f7a50978f3ed Binary files /dev/null and b/ISMIR_2025/MERT/__pycache__/networks.cpython-39.pyc differ diff --git a/ISMIR_2025/MERT/datalib.py b/ISMIR_2025/MERT/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..96a74e056c54a2e8cc7ced35fe63f2ce7d58d54f --- /dev/null +++ b/ISMIR_2025/MERT/datalib.py @@ -0,0 +1,203 @@ +import os +import glob +import torch +import torchaudio +import librosa +import numpy as np +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset +from imblearn.over_sampling import RandomOverSampler +from transformers import Wav2Vec2Processor +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +from transformers import Wav2Vec2FeatureExtractor +import scipy.signal as signal +import scipy.signal +# class FakeMusicCapsDataset(Dataset): +# def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): +# self.file_paths = file_paths +# self.labels = labels +# self.sr = sr +# self.target_samples = int(target_duration * sr) # Fixed length: 5 seconds + +# def __len__(self): +# return len(self.file_paths) + +# def __getitem__(self, idx): +# audio_path = self.file_paths[idx] +# label = self.labels[idx] + +# waveform, sr = torchaudio.load(audio_path) +# waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) +# waveform = waveform.mean(dim=0) # Convert to 
mono +# waveform = waveform.squeeze(0) + + +# current_samples = waveform.shape[0] + +# # **Ensure waveform is exactly `target_samples` long** +# if current_samples > self.target_samples: +# waveform = waveform[:self.target_samples] # Truncate if too long +# elif current_samples < self.target_samples: +# pad_length = self.target_samples - current_samples +# waveform = torch.nn.functional.pad(waveform, (0, pad_length)) # Pad if too short + +# return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long) # Ensure 2D shape (1, target_samples) + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_samples = int(target_duration * sr) # Fixed length: 10 seconds + self.processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) + + def __len__(self): + return len(self.file_paths) + + def highpass_filter(self, y, sr, cutoff=500, order=5): + if isinstance(sr, np.ndarray): + # print(f"[ERROR] sr is an array, taking mean value. Original sr: {sr}") + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + # print(f"[DEBUG] Highpass filter using sr={sr}, cutoff={cutoff}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.") + nyquist = 0.5 * sr + # print(f"[DEBUG] Nyquist frequency={nyquist}") + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + # print(f"[DEBUG] Adjusted cutoff={cutoff}, normal_cutoff={normal_cutoff}") + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + + target_sr = self.processor.sampling_rate + + if sr != target_sr: + resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) + waveform = resampler(waveform) + + waveform = waveform.mean(dim=0).squeeze(0) # [Time] + + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] # Truncate + elif current_samples < self.target_samples: + pad_length = self.target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) # Pad + + if isinstance(waveform, torch.Tensor): + waveform = waveform.numpy() # Tensor일 경우에만 변환 + + inputs = self.processor(waveform, sampling_rate=target_sr, return_tensors="pt", padding=True) + + return inputs["input_values"].squeeze(0), torch.tensor(label, dtype=torch.long) # [1, time] → [time] + + @staticmethod + def collate_fn(batch, target_samples=16000 * 10): + + inputs, labels = zip(*batch) # Unzip batch + + processed_inputs = [] + for waveform in inputs: + current_samples = waveform.shape[0] + + if current_samples > target_samples: + start_idx = (current_samples - target_samples) // 2 + cropped_waveform = waveform[start_idx:start_idx + target_samples] + else: + pad_length = target_samples - current_samples + cropped_waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + processed_inputs.append(cropped_waveform) + + processed_inputs = torch.stack(processed_inputs) # 
[batch, target_samples] + labels = torch.tensor(labels, dtype=torch.long) # [batch] + + return processed_inputs, labels + + def preprocess_audio(audio_path, target_sr=16000, max_length=160000): + """ + 오디오를 모델 입력에 맞게 변환 + - target_sr: 16kHz로 변환 + - max_length: 최대 길이 160000 (10초) + """ + waveform, sr = torchaudio.load(audio_path) + + # Resample if needed + if sr != target_sr: + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) + + # Convert to mono + waveform = waveform.mean(dim=0).unsqueeze(0) # (1, sequence_length) + + current_samples = waveform.shape[1] + if current_samples > max_length: + start_idx = (current_samples - max_length) // 2 + waveform = waveform[:, start_idx:start_idx + max_length] + elif current_samples < max_length: + pad_length = max_length - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +# Closed Test: FakeMusicCaps 데이터셋 사용 +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +# Open Set Test: SUNOCAPS_PATH 데이터 포함 +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +# Closed Train, Val +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +# Closed Set Test용 데이터셋 +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +# Open Set Test용 데이터셋 +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +# Oversampling 적용 +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"📌 Train Original FAKE: {len(gen_train)}") +print(f"📌 Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"📌 Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/MERT/main.py b/ISMIR_2025/MERT/main.py new file mode 100644 index 0000000000000000000000000000000000000000..2992ffcb72eea51d096f2f40fca04c524d309cd3 --- /dev/null +++ b/ISMIR_2025/MERT/main.py @@ -0,0 +1,197 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score 
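# The rest of this script reads hyperparameters from an `args` namespace
# (args.gpu, args.batch_size, args.checkpoint_dir, args.pretrain_epochs,
# args.finetune_epochs, args.learning_rate, args.finetune_lr, args.weight_decay)
# but never constructs one in this file. A minimal argparse sketch follows; the
# argument names mirror the attributes used later in the script, while every
# default value is an assumption rather than a setting taken from the original code.
import argparse

parser = argparse.ArgumentParser(description="AI music detection training with MERT")
parser.add_argument('--gpu', type=str, default='0', help='GPU ID (assumed default)')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size (assumed default)')
parser.add_argument('--pretrain_epochs', type=int, default=10, help='Pretraining epochs (assumed default)')
parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (assumed default)')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='Pretraining learning rate (assumed default)')
parser.add_argument('--finetune_lr', type=float, default=1e-4, help='Fine-tuning learning rate (assumed default)')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay (assumed default)')
parser.add_argument('--checkpoint_dir', type=str, default='./ckpt/mert', help='Checkpoint directory (assumed default)')
args = parser.parse_args()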
+import wandb +import argparse +from transformers import AutoModel, AutoConfig, Wav2Vec2FeatureExtractor +from ISMIR_2025.MERT.datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from ISMIR_2025.MERT.networks import MERTFeatureExtractor +# Set device +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Seed for reproducibility +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +# Initialize wandb +wandb.init(project="mert", name=f"hpfilter_pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +# Load datasets +print("🔍 Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, collate_fn=FakeMusicCapsDataset.collate_fn) + +# Model Checkpoint Paths +pretrain_ckpt = os.path.join(args.checkpoint_dir, f"mert_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"mert_finetune_{args.finetune_epochs}.pth") + +# Load Music2Vec Model for Pretraining +print("🔍 Initializing MERT model for Pretraining...") + +config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) +if not hasattr(config, "conv_pos_batch_norm"): + setattr(config, "conv_pos_batch_norm", False) + +mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device) +mert_model = MERTFeatureExtractor().to(device) + +# Loss and Optimizer +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(mert_model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +# Training function +def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + # inputs = inputs.float() + # output = model(inputs) + output = model(inputs) + + # Check if the output is a tensor or an object with logits + if isinstance(output, torch.Tensor): + logits = output + elif hasattr(output, "logits"): + logits = output.logits + elif isinstance(output, (tuple, list)): + logits = output[0] + else: + raise ValueError("Unexpected model output type") + + loss = criterion(logits, labels) + + + # loss = criterion(output, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = output.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + scheduler.step() + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="binary") + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary", pos_label=1) + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": 
accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, optimizer, criterion, device, epoch, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Validation Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + output = model(inputs) + + # Check if the output is a tensor or an object with logits + if isinstance(output, torch.Tensor): + logits = output + elif hasattr(output, "logits"): + logits = output.logits + elif isinstance(output, (tuple, list)): + logits = output[0] + else: + raise ValueError("Unexpected model output type") + + loss = criterion(logits, labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + scheduler.step() + accuracy = total_correct / total_samples + val_f1 = f1_score(all_labels, all_preds, average="weighted") + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": val_f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {val_f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, val_f1 + + +print("\n🔍 Step 1: Self-Supervised Pretraining on REAL Data") +# for epoch in range(args.pretrain_epochs): +# train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain") +# torch.save(mert_model.state_dict(), pretrain_ckpt) +# print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +# print("\n🔍 Initializing CCV Model for Fine-Tuning...") +# mert_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True).to(device) +# mert_model.feature_extractor.load_state_dict(torch.load(pretrain_ckpt), strict=False) + +# optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +print("\n🔍 Step 2: Fine-Tuning CCV Model") +for epoch in range(args.finetune_epochs): + train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(mert_model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! 
Model saved at: {finetune_ckpt}") + +print("\n🔍 Step 2: Fine-Tuning MERT Model") +mert_model.load_state_dict(torch.load(pretrain_ckpt), strict=False) + +optimizer = optim.Adam(mert_model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +for epoch in range(args.finetune_epochs): + train(mert_model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(mert_model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}") \ No newline at end of file diff --git a/ISMIR_2025/MERT/networks.py b/ISMIR_2025/MERT/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..2e985e909ccd7b63f5a44bc82d7e368903e92c07 --- /dev/null +++ b/ISMIR_2025/MERT/networks.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +from transformers import AutoModel, AutoConfig + +class MERTFeatureExtractor(nn.Module): + def __init__(self, freeze_feature_extractor=True): + super(MERTFeatureExtractor, self).__init__() + config = AutoConfig.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True) + if not hasattr(config, "conv_pos_batch_norm"): + setattr(config, "conv_pos_batch_norm", False) + self.mert = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", config=config, trust_remote_code=True) + + if freeze_feature_extractor: + self.freeze() + + def forward(self, input_values): + # 입력: [batch, time] + # 사전학습된 MERT의 hidden_states 추출 (예시로 모든 레이어의 hidden state 사용) + with torch.no_grad(): + outputs = self.mert(input_values, output_hidden_states=True) + # hidden_states: tuple of [batch, time, feature_dim] + # 여러 레이어의 hidden state를 스택한 뒤 시간축에 대해 평균하여 feature를 얻음 + hidden_states = torch.stack(outputs.hidden_states) # [num_layers, batch, time, feature_dim] + hidden_states = hidden_states.detach().clone().requires_grad_(True) + time_reduced = hidden_states.mean(dim=2) # [num_layers, batch, feature_dim] + time_reduced = time_reduced.permute(1, 0, 2) # [batch, num_layers, feature_dim] + return time_reduced + + def freeze(self): + for param in self.mert.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.mert.parameters(): + param.requires_grad = True + + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm1 = nn.LayerNorm(embed_dim) + self.layer_norm2 = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + + def forward(self, x, cross_input): + # x와 cross_input 간의 어텐션 수행 + attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input) + x = self.layer_norm1(x + attn_output) + ff_output = self.feed_forward(x) + x = self.layer_norm2(x + ff_output) + return x + + +class CCV(nn.Module): + def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): + super(CCV, self).__init__() + # MERT 기반 feature extractor (pretraining weight로부터 유의미한 피쳐 추출) + self.feature_extractor = MERTFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor) + # Cross-Attention 레이어 여러 층 + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + # Transformer Encoder (배치 차원 고려) + encoder_layer = 
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + # 분류기 + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, 256), + nn.BatchNorm1d(256), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) + ) + + + def forward(self, input_values): + """ + input_values: Tensor [batch, time] + 1. MERT로부터 feature 추출 → [batch, num_layers, feature_dim] + 2. 임베딩 차원 맞추기 위해 transpose → [batch, feature_dim, num_layers] + 3. Cross-Attention 적용 + 4. Transformer Encoding 후 평균 풀링 + 5. 분류기 통과하여 최종 출력(logits) 반환 + """ + features = self.feature_extractor(input_values) # [batch, num_layers, feature_dim] + # embed_dim는 보통 feature_dim과 동일하게 맞춤 (예시: 768) + # features = features.permute(0, 2, 1) # [batch, embed_dim, num_layers] + + # Cross-Attention 적용 (여기서는 자기자신과의 어텐션으로 예시) + for layer in self.cross_attention_layers: + features = layer(features, features) + + # Transformer Encoder를 위해 시간 축(여기서는 num_layers 축)에 대해 평균 + features = features.mean(dim=1).unsqueeze(1) # [batch, 1, embed_dim] + encoded = self.transformer(features) # [batch, 1, embed_dim] + encoded = encoded.mean(dim=1) # [batch, embed_dim] + output = self.classifier(encoded) # [batch, num_classes] + return output, encoded + + def unfreeze_feature_extractor(self): + self.feature_extractor.unfreeze() diff --git a/ISMIR_2025/MERT/test.py b/ISMIR_2025/MERT/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0dba71d3b798ffd56b0afc123e659dca49a5f024 --- /dev/null +++ b/ISMIR_2025/MERT/test.py @@ -0,0 +1,114 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import MERTFeatureExtractor +import argparse +parser = argparse.ArgumentParser(description="AI Music Detection Testing with MERT") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default="/data/kym/AI_Music_Detection/Code/model/MERT/ckpt/1e-3/mert_finetune_10.pth", help='Path to the pretrained checkpoint') +parser.add_argument('--model_name', type=str, default="mert", help="Model name") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') 
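    # The loop below annotates each confusion-matrix cell with its raw count and
    # switches the text colour to white once a count exceeds half of the largest
    # cell, so labels stay legible on the dark end of the Blues colormap. If
    # per-class rates are preferred, the matrix could be row-normalised first,
    # e.g. cm = cm.astype(float) / cm.sum(axis=1, keepdims=True) (not done here).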
+ + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +model = MERTFeatureExtractor().to(device) + +ckpt_file = args.ckpt_path +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") +print(f"\nLoading MERT model from {ckpt_file}") +model.load_state_dict(torch.load(ckpt_file, map_location=device)) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) + +def test_mert(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating MERT Model on Test Set...") +test_mert(model, test_loader, device) diff --git a/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5105baa6b5cc20b8d2fe0a9b029b40c9e6f7c6af Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/config.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb9379cca8c0e244eb2a7b169fa2ce307f2bab05 Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/idr_torch.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc b/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..97ed2e09ebbddb16a923974a1ef3c30f2259ac34 Binary files /dev/null and b/ISMIR_2025/MERT/utils/__pycache__/utilities.cpython-311.pyc differ diff --git a/ISMIR_2025/MERT/utils/config.py b/ISMIR_2025/MERT/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..69f72ecd472eed266bb9a0d811d7eeb07a3c06db --- /dev/null +++ b/ISMIR_2025/MERT/utils/config.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv + +import numpy as np + +sample_rate = 32000 +clip_samples = sample_rate * 10 # Audio clips are 10-second + +# Load label +with open( + "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv", + "r", +) as f: + reader = csv.reader(f, delimiter=",") + lines = list(reader) + +labels = [] +ids = [] # Each label has a unique id such as "/m/068hy" +for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + +classes_num = len(labels) + +lb_to_ix = {label: i for i, label in enumerate(labels)} +ix_to_lb = {i: label for i, label in enumerate(labels)} + +id_to_ix = {id: i for i, id in enumerate(ids)} +ix_to_id = {i: id for i, id in enumerate(ids)} + +full_samples_per_class = np.array( + [ + 937432, + 16344, + 7822, + 10271, + 2043, + 14420, + 733, + 1511, + 1258, + 424, + 1751, + 704, + 369, + 590, + 1063, + 1375, + 5026, + 743, + 853, + 1648, + 714, + 1497, + 1251, + 2139, + 1093, + 133, + 224, + 39469, + 6423, + 407, + 1559, + 4546, + 6826, + 7464, + 2468, + 549, + 4063, + 334, + 587, + 238, + 1766, + 691, + 114, + 2153, + 236, + 209, + 421, + 740, + 269, + 959, + 137, + 4192, + 485, + 1515, + 655, + 274, + 69, + 157, + 1128, + 807, + 1022, + 346, + 98, + 680, + 890, + 352, + 4169, + 2061, + 1753, + 9883, + 1339, + 708, + 37857, + 18504, + 12864, + 2475, + 2182, + 757, + 3624, + 677, + 1683, + 3583, + 444, + 1780, + 2364, + 409, + 4060, + 3097, + 3143, + 502, + 723, + 600, + 230, + 852, + 1498, + 1865, + 1879, + 2429, + 5498, + 5430, + 2139, + 1761, + 1051, + 831, + 2401, + 2258, + 1672, + 1711, + 987, + 646, + 794, + 25061, + 5792, + 4256, + 96, + 8126, + 2740, + 752, + 513, + 554, + 106, + 254, + 1592, + 556, + 331, + 615, + 2841, + 737, + 265, + 1349, + 358, + 1731, + 1115, + 295, + 1070, + 972, + 174, + 937780, + 112337, + 42509, + 49200, + 11415, + 6092, + 13851, + 2665, + 1678, + 13344, + 2329, + 1415, + 2244, + 1099, + 5024, + 9872, + 10948, + 4409, + 2732, + 1211, + 1289, + 4807, + 5136, + 1867, + 16134, + 14519, + 3086, + 19261, + 6499, + 4273, + 2790, + 8820, + 1228, + 1575, + 4420, + 3685, + 2019, + 664, + 324, + 513, + 411, + 436, + 2997, + 5162, + 3806, + 1389, + 899, + 8088, + 7004, + 1105, + 3633, + 2621, + 9753, + 1082, + 26854, + 3415, + 4991, + 2129, + 5546, + 4489, + 2850, + 1977, + 1908, + 1719, + 1106, + 1049, + 152, + 136, + 802, + 488, + 592, + 2081, + 2712, + 1665, + 1128, + 250, + 544, + 789, + 2715, + 8063, + 7056, + 2267, + 8034, + 6092, + 3815, + 1833, + 3277, + 8813, + 2111, + 4662, + 2678, + 2954, + 5227, + 1472, + 2591, + 3714, + 1974, + 1795, + 4680, + 3751, + 6585, + 2109, + 36617, + 6083, + 16264, + 17351, + 3449, + 5034, + 3931, + 2599, + 4134, + 3892, + 2334, + 2211, + 4516, + 2766, + 2862, + 3422, + 1788, + 2544, + 2403, + 2892, + 4042, + 3460, + 1516, + 1972, + 1563, + 1579, + 2776, + 1647, + 4535, + 3921, + 1261, + 6074, + 2922, + 3068, + 1948, + 4407, + 712, + 1294, + 1019, + 1572, + 3764, + 5218, + 975, + 1539, + 6376, + 1606, + 6091, + 1138, + 1169, + 7925, + 3136, 
+ 1108, + 2677, + 2680, + 1383, + 3144, + 2653, + 1986, + 1800, + 1308, + 1344, + 122231, + 12977, + 2552, + 2678, + 7824, + 768, + 8587, + 39503, + 3474, + 661, + 430, + 193, + 1405, + 1442, + 3588, + 6280, + 10515, + 785, + 710, + 305, + 206, + 4990, + 5329, + 3398, + 1771, + 3022, + 6907, + 1523, + 8588, + 12203, + 666, + 2113, + 7916, + 434, + 1636, + 5185, + 1062, + 664, + 952, + 3490, + 2811, + 2749, + 2848, + 15555, + 363, + 117, + 1494, + 1647, + 5886, + 4021, + 633, + 1013, + 5951, + 11343, + 2324, + 243, + 372, + 943, + 734, + 242, + 3161, + 122, + 127, + 201, + 1654, + 768, + 134, + 1467, + 642, + 1148, + 2156, + 1368, + 1176, + 302, + 1909, + 61, + 223, + 1812, + 287, + 422, + 311, + 228, + 748, + 230, + 1876, + 539, + 1814, + 737, + 689, + 1140, + 591, + 943, + 353, + 289, + 198, + 490, + 7938, + 1841, + 850, + 457, + 814, + 146, + 551, + 728, + 1627, + 620, + 648, + 1621, + 2731, + 535, + 88, + 1736, + 736, + 328, + 293, + 3170, + 344, + 384, + 7640, + 433, + 215, + 715, + 626, + 128, + 3059, + 1833, + 2069, + 3732, + 1640, + 1508, + 836, + 567, + 2837, + 1151, + 2068, + 695, + 1494, + 3173, + 364, + 88, + 188, + 740, + 677, + 273, + 1533, + 821, + 1091, + 293, + 647, + 318, + 1202, + 328, + 532, + 2847, + 526, + 721, + 370, + 258, + 956, + 1269, + 1641, + 339, + 1322, + 4485, + 286, + 1874, + 277, + 757, + 1393, + 1330, + 380, + 146, + 377, + 394, + 318, + 339, + 1477, + 1886, + 101, + 1435, + 284, + 1425, + 686, + 621, + 221, + 117, + 87, + 1340, + 201, + 1243, + 1222, + 651, + 1899, + 421, + 712, + 1016, + 1279, + 124, + 351, + 258, + 7043, + 368, + 666, + 162, + 7664, + 137, + 70159, + 26179, + 6321, + 32236, + 33320, + 771, + 1169, + 269, + 1103, + 444, + 364, + 2710, + 121, + 751, + 1609, + 855, + 1141, + 2287, + 1940, + 3943, + 289, + ] +) \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/confusion_matrix_plot.py b/ISMIR_2025/MERT/utils/confusion_matrix_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e57d6d77e51949970ea76d8400d78ed6540cc155 --- /dev/null +++ b/ISMIR_2025/MERT/utils/confusion_matrix_plot.py @@ -0,0 +1,29 @@ +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt +import numpy as np + +def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
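    # Same cell-annotation logic as plot_confusion_matrix in test.py, but instead
    # of saving a PNG this helper hands the finished figure to a TensorBoard-style
    # summary writer via writer.add_figure(...) below, tagged with the current epoch.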
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + writer.add_figure("Confusion Matrix", fig, epoch) \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/freqeuncy.py b/ISMIR_2025/MERT/utils/freqeuncy.py new file mode 100644 index 0000000000000000000000000000000000000000..b21c5222467ec4906c63e5b9d02052a69aeb67e2 --- /dev/null +++ b/ISMIR_2025/MERT/utils/freqeuncy.py @@ -0,0 +1,24 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt + +# 🔹 오디오 파일 로드 +file_real = "/path/to/real_audio.wav" # Real 오디오 경로 +file_fake = "/path/to/generative_audio.wav" # AI 생성 오디오 경로 + +def plot_spectrogram(audio_file, title): + y, sr = librosa.load(audio_file, sr=16000) # 샘플링 레이트 16kHz + D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max) # STFT 변환 + + plt.figure(figsize=(10, 4)) + librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma') + plt.colorbar(format='%+2.0f dB') + plt.title(title) + plt.ylim(4000, 16000) # 4kHz 이상 고주파 영역만 표시 + plt.show() + +# 🔹 Real vs Generative Spectrogram 비교 +plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)") +plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)") + diff --git a/ISMIR_2025/MERT/utils/hf_vis.py b/ISMIR_2025/MERT/utils/hf_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c99b61bfb27f99880b0c44313daf476e6c0c278f --- /dev/null +++ b/ISMIR_2025/MERT/utils/hf_vis.py @@ -0,0 +1,89 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as signal +import torch +import torch.nn as nn +import soundfile as sf + +from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention + + +def highpass_filter(y, sr, cutoff=500, order=5): + """High-pass filter to remove low frequencies below `cutoff` Hz.""" + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"): + """Plot waveform comparison and spectrograms in a single figure.""" + fig, axes = plt.subplots(3, 1, figsize=(12, 12)) + + # 1️⃣ Waveform Comparison + time = np.linspace(0, len(y_original) / sr, len(y_original)) + axes[0].plot(time, y_original, label='Original', alpha=0.7) + axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed') + axes[0].set_xlabel("Time (s)") + axes[0].set_ylabel("Amplitude") + axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)") + axes[0].legend() + + # 2️⃣ Spectrogram - Original + S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max) + img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1]) + axes[1].set_title("Original Spectrogram") + fig.colorbar(img, ax=axes[1], format="%+2.0f dB") + + # 3️⃣ Spectrogram - High-pass Filtered + S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max) + img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2]) + axes[2].set_title("High-pass Filtered Spectrogram") + fig.colorbar(img, ax=axes[2], format="%+2.0f dB") + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.show() + + +def 
load_model(checkpoint_path, model_class, device): + """Load a trained model from checkpoint.""" + model = model_class() + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + model.to(device) + model.eval() + return model + +def predict_audio(model, audio_tensor, device): + """Make predictions using a trained model.""" + with torch.no_grad(): + audio_tensor = audio_tensor.unsqueeze(0).to(device) # Add batch dimension + output = model(audio_tensor) + prediction = torch.argmax(output, dim=1).cpu().numpy()[0] + return prediction + +# Load audio +audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav" # Replace with actual file path +y, sr = librosa.load(audio_path, sr=None) +y_filtered = highpass_filter(y, sr, cutoff=500) + +# Convert audio to tensor +audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0) +audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0) + +# Load models +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device) +highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device) + +# Predict +original_pred = predict_audio(original_model, audio_tensor, device) +highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device) + +print(f"Original Model Prediction: {original_pred}") +print(f"High-pass Filter Model Prediction: {highpass_pred}") + +# Generate combined visualization (all plots in one image) +plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png") diff --git a/ISMIR_2025/MERT/utils/idr_torch.py b/ISMIR_2025/MERT/utils/idr_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e76040394ce27390c27bd8ef022e126d8e55dc --- /dev/null +++ b/ISMIR_2025/MERT/utils/idr_torch.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import hostlist + +# get SLURM variables +# rank = int(os.environ["SLURM_PROCID"]) +local_rank = int(os.environ["SLURM_LOCALID"]) +size = int(os.environ["SLURM_NTASKS"]) +cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"]) + +# get node list from slurm +hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) + +# get IDs of reserved GPU +gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + +# define MASTER_ADD & MASTER_PORT +os.environ["MASTER_ADDR"] = hostnames[0] +os.environ["MASTER_PORT"] = str( + 12345 + int(min(gpu_ids)) +) # to avoid port conflict on the same node \ No newline at end of file diff --git a/ISMIR_2025/MERT/utils/mfcc.py b/ISMIR_2025/MERT/utils/mfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..5d63db14375fedcc1cc60f2ef3cecf5c70e9a8fb --- /dev/null +++ b/ISMIR_2025/MERT/utils/mfcc.py @@ -0,0 +1,266 @@ +import os +import glob +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader, random_split +import torch.nn.functional as F +from sklearn.metrics import precision_score, recall_score, f1_score +from tqdm import tqdm +import argparse +import wandb + +class RealFakeDataset(Dataset): + """ + 
audio/FakeMusicCaps/ + ├─ real/ + │ └─ MusicCaps/*.wav (label=0) + └─ generative/ + └─ .../*.wav (label=1) + """ + def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0): + + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) # 10초 = 160,000 샘플 + + self.file_paths = [] + self.labels = [] + + # Real 데이터 (label=0) + real_dir = os.path.join(root_dir, "real") + real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True) + for f in real_wav_files: + self.file_paths.append(f) + self.labels.append(0) + + # Generative 데이터 (label=1) + gen_dir = os.path.join(root_dir, "generative") + gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True) + for f in gen_wav_files: + self.file_paths.append(f) + self.labels.append(1) + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + # print(f"[DEBUG] Path: {audio_path}, Label: {label}") # 추가 + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + stretch_factor = self.target_samples / current_samples + waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor) + waveform = waveform[:self.target_samples] + + mfcc = librosa.feature.mfcc( + y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256 + ) + mfcc = librosa.util.normalize(mfcc) + + mfcc = np.expand_dims(mfcc, axis=0) + mfcc_tensor = torch.tensor(mfcc, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mfcc_tensor, label_tensor + + + +class AudioCNN(nn.Module): + def __init__(self, num_classes=2): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + + +def my_collate_fn(batch): + mel_list, label_list = zip(*batch) + + max_frames = max(m.shape[2] for m in mel_list) + + padded = [] + for m in mel_list: + diff = max_frames - m.shape[2] + if diff > 0: + print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}") + m = F.pad(m, (0, diff), mode='constant', value=0) + padded.append(m) + + + mel_batch = torch.stack(padded, dim=0) + label_batch = torch.tensor(label_list, dtype=torch.long) + return mel_batch, label_batch + + +class EarlyStopping: + def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False): + self.patience = patience + self.delta = delta + self.path = path + self.verbose = verbose + self.counter = 0 + self.best_loss = None + self.early_stop = False + + def __call__(self, val_loss, model): + if self.best_loss is None: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + elif val_loss > self.best_loss - self.delta: + self.counter += 1 + if 
self.verbose: + print(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + self.counter = 0 + + def _save_checkpoint(self, val_loss, model): + if self.verbose: + print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...") + torch.save(model.state_dict(), self.path) + +def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"): + if not os.path.exists("./ckpt/mfcc/"): + os.makedirs("./ckpt/mfcc/") + + wandb.init( + project="AI Music Detection", + name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}", + config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate}, + ) + + dataset = RealFakeDataset(root_dir=root_dir) + n_total = len(dataset) + n_train = int(n_total * 0.8) + n_val = n_total - n_train + train_ds, val_ds = random_split(dataset, [n_train, n_val]) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = AudioCNN(num_classes=2).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + best_val_loss = float('inf') + patience = 3 + patience_counter = 0 + + for epoch in range(1, epochs + 1): + print(f"\n[Epoch {epoch}/{epochs}]") + + # Training + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for mel_batch, labels in train_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(mel_batch) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + train_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + train_correct += (preds == labels).sum().item() + train_total += labels.size(0) + + train_pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + train_loss /= train_total + train_acc = train_correct / train_total + + # Validation + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_preds, all_labels = [], [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for mel_batch, labels in val_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + outputs = model(mel_batch) + loss = criterion(outputs, labels) + val_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + val_correct += (preds == labels).sum().item() + val_total += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_precision = precision_score(all_labels, all_preds, average="macro") + val_recall = recall_score(all_labels, all_preds, average="macro") + val_f1 = f1_score(all_labels, all_preds, average="macro") + + print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | " + f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} " + f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}") + + wandb.log({"train_loss": train_loss, "train_acc": train_acc, + "val_loss": val_loss, "val_acc": val_acc, + "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1}) + + if val_loss < best_val_loss: + best_val_loss = val_loss + 
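            # New best validation loss: reset the patience counter and persist this
            # epoch's weights as the best checkpoint. Note that the EarlyStopping
            # helper defined above is not instantiated inside train(); this manual
            # counter with patience = 3 is what actually stops training early.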
patience_counter = 0 + best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth" + torch.save(model.state_dict(), best_model_path) + print(f"[INFO] New best model saved: {best_model_path}") + else: + patience_counter += 1 + if patience_counter >= patience: + print("Early stopping triggered!") + break + + wandb.finish() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train AI Music Detection model.") + parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training") + parser.add_argument('--epochs', type=int, required=True, help="Number of epochs") + parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate") + parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset") + + args = parser.parse_args() + + train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir) diff --git a/ISMIR_2025/MERT/utils/utilities.py b/ISMIR_2025/MERT/utils/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..e0be98e8645b8bb1c838d3dc9ae49daac706df62 --- /dev/null +++ b/ISMIR_2025/MERT/utils/utilities.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import logging +import pickle + +import numpy as np + +from scipy import stats + +import csv +import json + +def create_folder(fd): + if not os.path.exists(fd): + os.makedirs(fd, exist_ok=True) + + +def get_filename(path): + path = os.path.realpath(path) + na_ext = path.split("/")[-1] + na = os.path.splitext(na_ext)[0] + return na + + +def get_sub_filepaths(folder): + paths = [] + for root, dirs, files in os.walk(folder): + for name in files: + path = os.path.join(root, name) + paths.append(path) + return paths + + +def create_logging(log_dir, filemode): + create_folder(log_dir) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def read_metadata(csv_path, audio_dir, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # first, count the audio names only of existing files on disk only + + audios_num = 0 + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if os.path.exists(os.path.join(audio_dir, audio_name)): + audios_num += 1 + + print("CSV audio files: %d" % (len(lines))) + print("Existing audio files: %d" % audios_num) + + # audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + n = 0 + for line in lines: + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if not os.path.exists(os.path.join(audio_dir, audio_name)): + continue + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + n += 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + + +def read_audioset_ontology(id_to_ix): + with open('../metadata/audioset_ontology.json', 'r') as f: + data = json.load(f) + + # Output: {'name': 'Bob', 'languages': ['English', 'French']} + sentences = [] + for el in data: + print(el.keys()) + id = el['id'] + if id in id_to_ix: + name = el['name'] + desc = el['description'] + # if '(' in desc: + # print(name, '---', desc) + # print(id_to_ix[id], name, '---', ) + + # sent = name + # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower() + # sent = desc.replace('(', '').replace(')', '').lower() + # sentences.append(sent) + sentences.append(desc) + # print(sent) + # break + return sentences + + +def original_read_metadata(csv_path, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # Thomas Pellegrini: added 02/12/2022 + # check if the audio files indeed exist, otherwise remove from list + + audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) # Audios are started with an extra 'Y' when downloading + audio_name = audio_name.replace("_0000_", "_0_") + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + +def read_audioset_label_tags(class_labels_indices_csv): + with open(class_labels_indices_csv, 'r') as f: + reader = csv.reader(f, delimiter=',') + lines = list(reader) + + labels = [] + ids = [] # Each label has a unique id such as "/m/068hy" + for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + + classes_num = len(labels) + + lb_to_ix = {label : i for i, label in enumerate(labels)} + ix_to_lb = {i : label for i, label in enumerate(labels)} + + id_to_ix = {id : i for i, id in enumerate(ids)} + ix_to_id = {i : id for i, id in enumerate(ids)} + + return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id + + + +def float32_to_int16(x): + # assert np.max(np.abs(x)) <= 1.5 + x = np.clip(x, -1, 1) + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def pad_or_truncate(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x[0:audio_length] + + +def pad_audio(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x + + +def d_prime(auc): + d_prime = stats.norm().ppf(auc) * np.sqrt(2.0) + return d_prime + + +class Mixup(object): + def __init__(self, mixup_alpha, random_seed=1234): + """Mixup coefficient generator.""" + self.mixup_alpha = mixup_alpha + self.random_state = np.random.RandomState(random_seed) + + def get_lambda(self, batch_size): + """Get mixup random coefficients. 
+ Args: + batch_size: int + Returns: + mixup_lambdas: (batch_size,) + """ + mixup_lambdas = [] + for n in range(0, batch_size, 2): + lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0] + mixup_lambdas.append(lam) + mixup_lambdas.append(1.0 - lam) + + return np.array(mixup_lambdas) + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + """Contain statistics of different training iterations.""" + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"bal": [], "test": []} + + def append(self, iteration, statistics, data_type): + statistics["iteration"] = iteration + self.statistics_dict[data_type].append(statistics) + + def dump(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + def load_state_dict(self, resume_iteration): + self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) + + resume_statistics_dict = {"bal": [], "test": []} + + for key in self.statistics_dict.keys(): + for statistics in self.statistics_dict[key]: + if statistics["iteration"] <= resume_iteration: + resume_statistics_dict[key].append(statistics) + + self.statistics_dict = resume_statistics_dict \ No newline at end of file diff --git a/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e29a727e05815865175b9cd4c06134b0656bf4bd Binary files /dev/null and b/ISMIR_2025/Model/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/Model/__pycache__/networks.cpython-39.pyc b/ISMIR_2025/Model/__pycache__/networks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e7e2780f5adc11469a6ebe0149397249411f41 Binary files /dev/null and b/ISMIR_2025/Model/__pycache__/networks.cpython-39.pyc differ diff --git a/ISMIR_2025/Model/datalib.py b/ISMIR_2025/Model/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..f2478eab2f6439506b085bc2c305a048c30ff827 --- /dev/null +++ b/ISMIR_2025/Model/datalib.py @@ -0,0 +1,206 @@ +import os +import glob +import random +import torch +import librosa +import numpy as np +import utils +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset, DataLoader +import scipy.signal as signal +import scipy.signal +from scipy.signal import butter, lfilter +import numpy as np +import scipy.signal as signal +import librosa +import torch +import random +from torch.utils.data import Dataset +import logging +import csv +import logging +import time +import numpy as np +import h5py +import torch +import torchaudio +# Oversampling Lib +from imblearn.over_sampling import RandomOverSampler + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, feat_type=['mel'], sr=16000, n_mels=64, target_duration=10.0, augment=True, augment_real=True): + self.file_paths = file_paths + self.labels = labels + self.feat_type = feat_type + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) + self.augment = augment + self.augment_real = 
augment_real + + + def pre_emphasis(self, x, alpha=0.97): + return np.append(x[0], x[1:] - alpha * x[:-1]) + + def highpass_filter(self, y, sr, cutoff=1000, order=5): + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + return signal.lfilter(b, a, y) + + def augment_audio(self, y, sr): + if random.random() < 0.5: + rate = random.uniform(0.8, 1.2) + y = librosa.effects.time_stretch(y=y, rate=rate) + + if random.random() < 0.5: + n_steps = random.randint(-2, 2) + y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps) + + if random.random() < 0.5: + noise_level = np.random.uniform(0.001, 0.005) + y = y + np.random.normal(0, noise_level, y.shape) + + if random.random() < 0.5: + gain = np.random.uniform(0.9, 1.1) + y = y * gain + + return y + + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + """ + Load and preprocess audio file. + """ + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + if label == 0: + if self.augment_real: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + waveform = self.augment_audio(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + start_idx = (current_samples - self.target_samples) // 2 + waveform = waveform[start_idx:start_idx + self.target_samples] + elif current_samples < self.target_samples: + waveform = np.pad(waveform, (0, self.target_samples - current_samples), mode='constant') + + + mel_spec = librosa.feature.melspectrogram( + y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256 + ) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + + log_mel_spec = np.expand_dims(log_mel_spec, axis=0) + mel_tensor = torch.tensor(log_mel_spec, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mel_tensor, label_tensor + + def extract_feature(self, waveform, feat): + """Extracts specified feature (mel, stft, cqt) from waveform.""" + try: + if feat == 'mel': + mel_spec = librosa.feature.melspectrogram(y=waveform, sr=self.sr, n_mels=self.n_mels, n_fft=1024, hop_length=256) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + return torch.tensor(log_mel_spec, dtype=torch.float).unsqueeze(0) + elif feat == 'stft': + stft = librosa.stft(waveform, n_fft=512, hop_length=128, window="hann") + logSTFT = np.log(np.abs(stft) + 1e-3) + return torch.tensor(logSTFT, dtype=torch.float).unsqueeze(0) + elif feat == 'cqt': + cqt = librosa.cqt(waveform, sr=self.sr, hop_length=128, bins_per_octave=24) + logCQT = np.log(np.abs(cqt) + 1e-3) + return torch.tensor(logCQT, dtype=torch.float).unsqueeze(0) + else: + raise ValueError(f"[ERROR] Unsupported feature type: {feat}") + except Exception as e: + print(f"[ERROR] Feature extraction failed for {feat}: {e}") + return None + + def highpass_filter(self, y, sr, cutoff=1000, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def preprocess_audio(audio_path, sr=16000, n_mels=64, target_duration=10.0): + try: + waveform, _ = librosa.load(audio_path, sr=sr, mono=True) + + target_samples = int(target_duration * sr) + if len(waveform) > target_samples: + start_idx = (len(waveform) - target_samples) // 2 + waveform = waveform[start_idx:start_idx + target_samples] + elif len(waveform) < target_samples: + waveform = np.pad(waveform, (0, target_samples - len(waveform)), mode='constant') + mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=n_mels, n_fft=1024, hop_length=256) + log_mel_spec = librosa.power_to_db(mel_spec, ref=np.max) + return torch.tensor(log_mel_spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0) + + except Exception as e: + print(f"[ERROR] 전처리 실패: {audio_path} | 오류: {e}") + return None + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled +print(f"type(train_labels_resampled): {type(train_labels_resampled)}") + +print(f"Train Org Fake: {len(gen_val)}") +print(f"Train set (Oversampled) - Real: {sum(1 for label in train_labels if label == 0)}, " + f"Fake: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - Real: {len(real_val)}, Fake: {len(gen_val)}, Total: {len(val_files)}") +print(f"Closed Test set - Real: {len(real_files)}, Fake: {len(gen_files)}, Total: {len(closed_test_files)}") +print(f"Open Test set - Real: {len(open_real_files)}, Fake: {len(open_gen_files)}, Total: {len(open_test_files)}") \ No newline at end of file diff --git a/ISMIR_2025/Model/main.py b/ISMIR_2025/Model/main.py new file mode 100644 
index 0000000000000000000000000000000000000000..68ce313c7adc981df67673e8e2f1c40472f22725 --- /dev/null +++ b/ISMIR_2025/Model/main.py @@ -0,0 +1,336 @@ +import os +import random +import numpy as np +import torch +import torch.nn.functional as F +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +import wandb +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, balanced_accuracy_score +from datalib import FakeMusicCapsDataset +from datalib import ( + FakeMusicCapsDataset, + train_files, val_files, train_labels, val_labels, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + preprocess_audio +) +from datalib import preprocess_audio +from networks import CCV +from attentionmap import visualize_attention_map +from confusion_matrix import plot_confusion_matrix + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) +''' +python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --oversample True + +audiocnn encoder - crossattn based decoder (ViT) model +''' +# Argument parsing +import argparse +parser = argparse.ArgumentParser(description='AI Music Detection Training') +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV', help='Model name') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--epochs', type=int, default=10, help='Number of epochs') +parser.add_argument('--audio_duration', type=float, default=10, help='Length of the audio slice in seconds') +parser.add_argument('--patience_counter', type=int, default=5, help='Early stopping patience') +parser.add_argument('--log_dir', type=str, default='', help='TensorBoard log directory') +parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory') +parser.add_argument("--weight_decay", type=float, default=0.05, help="weight decay (default: 0.0)") +parser.add_argument("--loss_type", type=str, choices=["ce", "weighted_ce", "focal"], default="ce", help="Loss function type") + +parser.add_argument('--inference', type=str, help='Path to a .wav file for inference') +parser.add_argument("--closed_test", action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument("--open_test", action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument("--oversample", type=bool, default=True, help="Apply Oversampling to balance classes") # real data oversampling + + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) +wandb.init(project="", + name=f"{args.model_name}_lr{args.learning_rate}_ep{args.epochs}_bs{args.batch_size}", config=args) + +if args.model_name == 'CCV': + model = CCV(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).cuda() + feat_type = 'mel' +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +model = model.to(device) +print(f"Using model: {args.model_name}, Parameters: {count_parameters(model)}") +print(f"weight_decay WD: {args.weight_decay}") + +optimizer = 
optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) + +if args.loss_type == "ce": + print("Using CrossEntropyLoss") + criterion = nn.CrossEntropyLoss() + +elif args.loss_type == "weighted_ce": + print("Using Weighted CrossEntropyLoss") + + num_real = sum(1 for label in train_labels if label == 0) + num_fake = sum(1 for label in train_labels if label == 1) + + total_samples = num_real + num_fake + weight_real = total_samples / (2 * num_real) + weight_fake = total_samples / (2 * num_fake) + class_weights = torch.tensor([weight_real, weight_fake]).to(device) + + criterion = nn.CrossEntropyLoss(weight=class_weights) + +elif args.loss_type == "focal": + print("Using Focal Loss") + + class FocalLoss(torch.nn.Module): + def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'): + super(FocalLoss, self).__init__() + self.alpha = alpha + self.gamma = gamma + self.reduction = reduction + + def forward(self, inputs, targets): + ce_loss = F.cross_entropy(inputs, targets, reduction='none') + pt = torch.exp(-ce_loss) + focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss + + if self.reduction == 'mean': + return focal_loss.mean() + elif self.reduction == 'sum': + return focal_loss.sum() + else: + return focal_loss + + criterion = FocalLoss().to(device) + +if not os.path.exists(args.ckpt_path): + os.makedirs(args.ckpt_path) + +train_dataset = FakeMusicCapsDataset(train_files, train_labels, feat_type=feat_type, target_duration=args.audio_duration) +val_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=16) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args): + writer = SummaryWriter(log_dir=args.log_dir) + best_val_bal_acc = float('inf') + early_stop_cnt = 0 + log_interval = 1 + + for epoch in range(args.epochs): + print(f"\n[Epoch {epoch + 1}/{args.epochs}]") + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + + all_train_preds= [] + all_train_labels = [] + attention_maps = [] + + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for batch_idx, (data, target) in enumerate(train_pbar): + data = data.to(device) + target = target.to(device) + output = model(data) + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + train_correct += (preds == target).sum().item() + train_total += target.size(0) + + all_train_labels.extend(target.cpu().numpy()) + all_train_preds.extend(preds.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + train_loss /= train_total + train_acc = train_correct / train_total + train_bal_acc = balanced_accuracy_score(all_train_labels, all_train_preds) + train_precision = precision_score(all_train_labels, all_train_preds, average="binary") + train_recall = recall_score(all_train_labels, all_train_preds, average="binary") + train_f1 = f1_score(all_train_labels, all_train_preds, average="binary") + + wandb.log({ + "Train Loss": train_loss, "Train Accuracy": train_acc, + "Train Precision": train_precision, "Train Recall": train_recall, + "Train F1 Score": train_f1, "Train 
B_ACC": train_bal_acc, + }) + + print(f"Train Epoch: {epoch+1} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f} | " + f"Train B_ACC: {train_bal_acc:.4f} | Train Prec: {train_precision:.3f} | " + f"Train Rec: {train_recall:.3f} | Train F1: {train_f1:.3f}") + + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_val_preds, all_val_labels = [], [] + attention_maps = [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for data, target in val_pbar: + data, target = data.to(device), target.to(device) + output = model(data) + loss = criterion(output, target) + val_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + val_correct += (preds == target).sum().item() + val_total += target.size(0) + + all_val_labels.extend(target.cpu().numpy()) + all_val_preds.extend(preds.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_bal_acc = balanced_accuracy_score(all_val_labels, all_val_preds) + val_precision = precision_score(all_val_labels, all_val_preds, average="binary") + val_recall = recall_score(all_val_labels, all_val_preds, average="binary") + val_f1 = f1_score(all_val_labels, all_val_preds, average="binary") + + wandb.log({ + "Validation Loss": val_loss, "Validation Accuracy": val_acc, + "Validation Precision": val_precision, "Validation Recall": val_recall, + "Validation F1 Score": val_f1, "Validation B_ACC": val_bal_acc, + }) + + print(f"Val Epoch: {epoch+1} [{batch_idx * len(data)}/{len(val_loader.dataset)} " + f"({100. * batch_idx / len(val_loader):.0f}%)]\t" + f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.3f} | " + f"Val B_ACC: {val_bal_acc:.4f} | Val Prec: {val_precision:.3f} | " + f"Val Rec: {val_recall:.3f} | Val F1: {val_f1:.3f}") + + if epoch % 1 == 0 and len(attention_maps) > 0: + print(f"Visualizing Attention Map at Epoch {epoch+1}") + + if isinstance(attention_maps[0], list): + attn_map_numpy = np.array([t.detach().cpu().numpy() for t in attention_maps[0]]) + elif isinstance(attention_maps[0], torch.Tensor): + attn_map_numpy = attention_maps[0].detach().cpu().numpy() + else: + attn_map_numpy = np.array(attention_maps[0]) + + print(f"Attention Map Shape: {attn_map_numpy.shape}") + + if len(attn_map_numpy) > 0: + fig, ax = plt.subplots(figsize=(10, 8)) + ax.imshow(attn_map_numpy[0], cmap='viridis', interpolation='nearest') + ax.set_title(f"Attention Map - Epoch {epoch+1}") + plt.colorbar(ax.imshow(attn_map_numpy[0], cmap='viridis')) + plt.savefig("") + plt.show() + else: + print(f"Warning: attention_maps[0] is empty! 
Shape={attn_map_numpy.shape}") + + if val_bal_acc < best_val_bal_acc: + best_val_bal_acc = val_bal_acc + early_stop_cnt = 0 + torch.save(model.state_dict(), os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth")) + print("Best model saved.") + else: + early_stop_cnt += 1 + print(f'PATIENCE {early_stop_cnt}/{args.patience_counter}') + + if early_stop_cnt >= args.patience_counter: + print("Early stopping triggered.") + break + + scheduler.step() + plot_confusion_matrix(all_val_labels, all_val_preds, classes=["REAL", "FAKE"], writer=writer, epoch=epoch) + + wandb.finish() + writer.close() + +def predict(audio_path): + print(f"Loading model from {args.ckpt_path}/celoss_best_model_{args.model_name}.pth") + model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device)) + model.eval() + + input_tensor = preprocess_audio(audio_path).to(device) + + with torch.no_grad(): + output = model(input_tensor) + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + if ai_music_prob > 0.5: + print(f"FAKE MUSIC {ai_music_prob:.2%})") + else: + print(f"REAL MUSIC {100 - ai_music_prob * 100:.2f}%") + +def Test(model, test_loader, criterion, device): + model.load_state_dict(torch.load(os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth"), map_location=device)) + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in tqdm(test_loader, desc=" Test ", leave=False): + data, target = data.to(device), target.to(device) + output = model(data) + loss = criterion(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + +if __name__ == "__main__": + train(model, train_loader, val_loader, optimizer, scheduler, criterion, device, args) + if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type=feat_type, target_duration=args.audio_duration) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + + 
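    # Example invocations for the evaluation branches above (illustrative only;
    # the flags match the argparse definitions earlier in this file):
    #   python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --closed_test
    #   python3 main.py --model_name CCV --batch_size 32 --epochs 10 --loss_type ce --open_test
    # With neither flag set, the held-out 20% validation split is evaluated, as in the else branch above.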
print("\nEvaluating Model on Test Set...") + Test(model, test_loader, criterion, device) + + if args.inference: + if not os.path.exists(args.inference): + print(f"[ERROR] No File Found: {args.inference}") + else: + predict(args.inference) diff --git a/ISMIR_2025/Model/networks.py b/ISMIR_2025/Model/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..07914ba386647553f4696c452323c5c0a77f37ce --- /dev/null +++ b/ISMIR_2025/Model/networks.py @@ -0,0 +1,237 @@ +import torch +import torch.nn as nn + +class audiocnn(nn.Module): + def __init__(self, num_classes=2): + super(audiocnn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + +class AudioCNN(nn.Module): + def __init__(self, embed_dim=512): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) # 최종 -> (B, 32, 4, 4) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) # Flatten (B, C * H * W) + x = self.projection(x) # Project to embed_dim + return x + +class ViTDecoder(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(ViTDecoder, self).__init__() + + # Transformer layers + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x): + # Transformer expects input of shape (seq_len, batch, embed_dim) + x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim) + x = self.transformer(x) # Pass through Transformer + x = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim) + + x = self.classifier(x) # Classification head + return x + +class AudioCNNWithViTDecoder(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(AudioCNNWithViTDecoder, self).__init__() + self.encoder = AudioCNN(embed_dim=embed_dim) + self.decoder = ViTDecoder(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + + def forward(self, x): + x = self.encoder(x) # Pass through AudioCNN encoder + x = self.decoder(x) # Pass through ViT decoder + return x + + +# class AudioCNN(nn.Module): +# def __init__(self, num_classes=2): +# super(AudioCNN, self).__init__() +# self.conv_block = nn.Sequential( +# nn.Conv2d(1, 16, kernel_size=3, padding=1), +# nn.ReLU(), +# nn.MaxPool2d(2), +# nn.Conv2d(16, 32, kernel_size=3, padding=1), +# nn.ReLU(), +# nn.MaxPool2d(2), +# nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) +# ) +# self.fc_block = nn.Sequential( +# nn.Linear(32*4*4, 128), +# nn.ReLU(), +# nn.Linear(128, num_classes) +# ) 
+ + +# def forward(self, x): +# x = self.conv_block(x) +# # x.shape: (B,32,new_freq,new_time) + +# # 1) Flatten +# B, C, H, W = x.shape # 동적 shape +# x = x.view(B, -1) # (B, 32*H*W) + +# # 2) FC +# x = self.fc_block(x) +# return x + + + +class audio_crossattn(nn.Module): + def __init__(self, embed_dim=512): + super(audio_crossattn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) # 최종 출력 -> (B, 32, 4, 4) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) # Convolutional feature extraction + B, C, H, W = x.shape + x = x.view(B, -1) # Flatten (B, C * H * W) + x = self.projection(x) # Linear projection to embed_dim + return x + + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + + def forward(self, x, cross_input): + # Cross-attention between x and cross_input + attn_output, _ = self.multihead_attn(query=x, key=cross_input, value=cross_input) + x = self.layer_norm(x + attn_output) # Add & Norm + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) # Add & Norm + return x + +class ViTDecoderWithCrossAttention(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(ViTDecoderWithCrossAttention, self).__init__() + + # Cross-Attention layers + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + # Transformer Encoder layers + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + # Pass through Cross-Attention layers + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + + # Transformer expects input of shape (seq_len, batch, embed_dim) + x = x.unsqueeze(1).permute(1, 0, 2) # Add sequence dim (1, B, embed_dim) + x = self.transformer(x) # Pass through Transformer + embedding = x.mean(dim=0) # Take the mean over the sequence dimension (B, embed_dim) + + # Classification head + x = self.classifier(embedding) + return x, embedding + +# class AudioCNNWithViTDecoderAndCrossAttention(nn.Module): +# def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): +# super(AudioCNNWithViTDecoderAndCrossAttention, self).__init__() +# self.encoder = audio_crossattn(embed_dim=embed_dim) +# self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + +# def forward(self, x, cross_attention_input): +# # Pass through AudioCNN encoder +# x = self.encoder(x) + +# # Pass through ViTDecoder with Cross-Attention +# x = self.decoder(x, cross_attention_input) +# return x +class CCV(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, 
freeze_feature_extractor=True): + super(CCV, self).__init__() + self.encoder = AudioCNN(embed_dim=embed_dim) + self.decoder = ViTDecoderWithCrossAttention(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + if freeze_feature_extractor: + for param in self.encoder.parameters(): + param.requires_grad = False + for param in self.decoder.parameters(): + param.requires_grad = False + def forward(self, x, cross_attention_input=None): + # Pass through AudioCNN encoder + x = self.encoder(x) + + # If cross_attention_input is not provided, use the encoder output + if cross_attention_input is None: + cross_attention_input = x + + # Pass through ViTDecoder with Cross-Attention + x, embedding = self.decoder(x, cross_attention_input) + return x, embedding + +#--------------------------------------------------------- +''' +audiocnn weight frozen +crossatten decoder -lora tuning +''' + diff --git a/ISMIR_2025/Model/test.py b/ISMIR_2025/Model/test.py new file mode 100644 index 0000000000000000000000000000000000000000..0837d7094ef5affc3111a56bdfe4adfe67333bf6 --- /dev/null +++ b/ISMIR_2025/Model/test.py @@ -0,0 +1,129 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib_f import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks_f import CCV_Wav2Vec2 +import argparse + +parser = argparse.ArgumentParser(description="AI Music Detection Testing") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['audiocnn', 'CCV'], default='CCV_Wav2Vec2', help='Model name') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/tensorboard/wav2vec', help='Checkpoint directory') +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/test_results/w_celoss_repreprocess/wav2vec', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
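    # Optional variant (illustrative, not in the original code): for per-class
    # rates instead of raw counts, the matrix can be row-normalised first,
    #   cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
    # and the annotation loop below would then format cm_norm[i, j] with '.2f'.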
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +if args.model_name == 'CCV_Wav2Vec2': + model = CCV_Wav2Vec2(embed_dim=512, num_heads=8, num_layers=6, num_classes=2).to(device) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +ckpt_file = os.path.join(args.ckpt_path, f"best_model_{args.model_name}.pth") +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +print(f"\nLoading model from {ckpt_file}") + +# model.load_state_dict(torch.load(ckpt_file, map_location=device)) +# 병렬 +state_dict = torch.load(ckpt_file, map_location=device) +from collections import OrderedDict +new_state_dict = OrderedDict() +for k, v in state_dict.items(): + name = k[7:] if k.startswith("module.") else k + new_state_dict[name] = v +model.load_state_dict(new_state_dict) +# 병렬 +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, feat_type="mel", target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, feat_type="mel", target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, feat_type="mel", target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def Test(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device) diff --git a/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ed2fc175e6d0bd0ac67453d34fe473075f3a1c7 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/datalib.cpython-311.pyc differ diff --git 
a/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..392b7845da3002c5a1d4588c1162b82c9ec40b75 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bb6c0f9728de048a1dc2be8f3a720842f1791a3 Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/music2vec/__pycache__/networks.cpython-39.pyc b/ISMIR_2025/music2vec/__pycache__/networks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be8c867854d29edf217120eabb2f60ca6730a6dd Binary files /dev/null and b/ISMIR_2025/music2vec/__pycache__/networks.cpython-39.pyc differ diff --git a/ISMIR_2025/music2vec/datalib.py b/ISMIR_2025/music2vec/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..cac3cc9f878d5db91d2cfb3311ec859d8ac826c5 --- /dev/null +++ b/ISMIR_2025/music2vec/datalib.py @@ -0,0 +1,144 @@ +import os +import glob +import torch +import torchaudio +import librosa +import numpy as np +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset +from imblearn.over_sampling import RandomOverSampler +from transformers import Wav2Vec2Processor +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence +import scipy.signal as signal +import random + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0, augment=True): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_samples = int(target_duration * sr) + self.augment = augment + def __len__(self): + return len(self.file_paths) + + def augment_audio(self, y, sr): + if isinstance(y, torch.Tensor): + y = y.numpy() + if random.random() < 0.5: + rate = random.uniform(0.8, 1.2) + y = librosa.effects.time_stretch(y=y, rate=rate) + if random.random() < 0.5: + n_steps = random.randint(-2, 2) + y = librosa.effects.pitch_shift(y=y, sr=sr, n_steps=n_steps) + if random.random() < 0.5: + noise_level = np.random.uniform(0.001, 0.005) + y = y + np.random.normal(0, noise_level, y.shape) + if random.random() < 0.5: + gain = np.random.uniform(0.9, 1.1) + y = y * gain + return torch.tensor(y, dtype=torch.float32) + + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) + waveform = waveform.mean(dim=0) + current_samples = waveform.shape[0] + + if label == 0: + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + waveform = self.augment_audio(waveform, self.sr) + + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + pad_length = self.target_samples - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + # waveform = waveform.squeeze(0) + if isinstance(waveform, np.ndarray): + waveform = torch.tensor(waveform, dtype=torch.float32) + + return waveform.unsqueeze(0), torch.tensor(label, dtype=torch.long) + + def highpass_filter(self, y, sr, 
cutoff=500, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def preprocess_audio(audio_path, target_sr=16000, max_length=160000): + waveform, sr = torchaudio.load(audio_path) + if sr != target_sr: + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)(waveform) + + waveform = waveform.mean(dim=0).unsqueeze(0) + + current_samples = waveform.shape[1] + if current_samples > max_length: + start_idx = (current_samples - max_length) // 2 + waveform = waveform[:, start_idx:start_idx + max_length] + elif current_samples < max_length: + pad_length = max_length - current_samples + waveform = torch.nn.functional.pad(waveform, (0, pad_length)) + + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # Open Set 포함 데이터 + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"Train Original FAKE: {len(gen_train)}") +print(f"Train set (Oversampled) - REAL: {sum(1 for label in train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/music2vec/inference.py b/ISMIR_2025/music2vec/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f3200bff3f3f86877f747fb54e5602f912c16106 --- /dev/null +++ b/ISMIR_2025/music2vec/inference.py @@ -0,0 +1,64 @@ +import os +import torch +import torch.nn.functional as F +import 
torchaudio +import argparse +from datalib import preprocess_audio +from networks import Wav2Vec2ForFakeMusic + +# Argument Parsing +parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference") +parser.add_argument('--gpu', type=str, default='0', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory') +parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model') +parser.add_argument('--inference', type=str, required=True, help='Path to a .wav file for inference') +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# Load Model Checkpoint +if args.model_type == 'pretrain': + model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth") +elif args.model_type == 'finetune': + model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth") +else: + raise ValueError("Invalid model type. Choose between 'pretrain' or 'finetune'.") + +if not os.path.exists(model_file): + raise FileNotFoundError(f"Model checkpoint not found: {model_file}") + +if args.model_name == 'Wav2Vec2ForFakeMusic': + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune')) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +def predict(audio_path): + print(f"\n🔍 Loading model from {model_file}") + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}") + + model.to(device) + model.eval() + + input_tensor = preprocess_audio(audio_path).to(device) + print(f"Input shape after preprocessing: {input_tensor.shape}") + + with torch.no_grad(): + output = model(input_tensor) + print(f"Raw model output (logits): {output}") + + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + print(f"Softmax Probabilities: {probabilities}") + print(f"AI Music Probability: {ai_music_prob:.4f}") + + if ai_music_prob > 0.5: + print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})") + else: + print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)") + +if __name__ == "__main__": + predict(args.inference) diff --git a/ISMIR_2025/music2vec/main.py b/ISMIR_2025/music2vec/main.py new file mode 100644 index 0000000000000000000000000000000000000000..7a06ae4cf712d8496142acd34d506f6e14a391a2 --- /dev/null +++ b/ISMIR_2025/music2vec/main.py @@ -0,0 +1,155 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score +import wandb +import argparse +from transformers import Wav2Vec2Processor +from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from networks import Music2VecClassifier, CCV + +parser = argparse.ArgumentParser(description='AI Music Detection Training with Music2Vec + CCV') +parser.add_argument('--gpu', type=str, default='2', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') 
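# Note: --learning_rate drives the optimizer for Step 1 (self-supervised pretraining)
# below, while --finetune_lr (next argument) re-initialises the optimizer for Step 2
# (fine-tuning).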
+parser.add_argument('--finetune_lr', type=float, default=1e-3, help='Fine-Tune Learning rate') +parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)') +parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)') +parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory') +parser.add_argument('--weight_decay', type=float, default=0.001, help="Weight decay for optimizer") + +args = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +wandb.init(project="music2vec_ccv", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +print("Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) + +pretrain_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"music2vec_ccv_finetune_{args.finetune_epochs}.pth") + +print("Initializing Music2Vec model for Pretraining...") +processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") +model = Music2VecClassifier(freeze_feature_extractor=False).to(device) # Pretraining에서는 freeze + +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +def train(model, dataloader, optimizer, criterion, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + labels = labels.to(device) + inputs = inputs.to(device) + + logits = model(inputs) + loss = criterion(logits, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = logits.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + scheduler.step() + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="binary") + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, criterion, device, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + 
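    # Batches arrive from FakeMusicCapsDataset as [B, 1, T] waveforms; the
    # squeeze(1) in the loop below reduces them to [B, T], which is the
    # input_values shape the Data2Vec-based classifier expects.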
with torch.no_grad(): + for inputs, labels in tqdm(dataloader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) + outputs = model(inputs) + loss = criterion(outputs, labels) + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, f1 + +print("\nStep 1: Self-Supervised Pretraining on REAL Data") +for epoch in range(args.pretrain_epochs): + train(model, train_loader, optimizer, criterion, device, epoch, phase="Pretrain") + +torch.save(model.state_dict(), pretrain_ckpt) +print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +print("\nInitializing Music2Vec + CCV Model for Fine-Tuning...") +model.load_state_dict(torch.load(pretrain_ckpt)) + +# model = CCV(embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True).to(device) +model = Music2VecClassifier(freeze_feature_extractor=False).to(device) +optimizer = optim.Adam(model.parameters(), lr=args.finetune_lr, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) + +print("\nStep 2: Fine-Tuning CCV Model using Music2Vec Features") +for epoch in range(args.finetune_epochs): + train(model, train_loader, optimizer, criterion, device, epoch, phase="Fine-Tune") + +torch.save(model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! 
Model saved at: {finetune_ckpt}") diff --git a/ISMIR_2025/music2vec/networks.py b/ISMIR_2025/music2vec/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..621f30c95a3933ac286b9a2a54c8b4be2e7aa550 --- /dev/null +++ b/ISMIR_2025/music2vec/networks.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +from transformers import Data2VecAudioModel, Wav2Vec2Processor + +class Music2VecClassifier(nn.Module): + def __init__(self, num_classes=2, freeze_feature_extractor=True): + super(Music2VecClassifier, self).__init__() + + self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") + self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + + if freeze_feature_extractor: + for param in self.music2vec.parameters(): + param.requires_grad = False + + # Conv1d for learnable weighted average across layers + self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + + # Classification head + self.classifier = nn.Sequential( + nn.Linear(self.music2vec.config.hidden_size, 256), + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) + ) + + def forward(self, input_values): + input_values = input_values.squeeze(1) # Ensure shape [batch, time] + + with torch.no_grad(): + outputs = self.music2vec(input_values, output_hidden_states=True) + hidden_states = torch.stack(outputs.hidden_states) + time_reduced = hidden_states.mean(dim=2) + time_reduced = time_reduced.permute(1, 0, 2) + weighted_avg = self.conv1d(time_reduced).squeeze(1) + + return self.classifier(weighted_avg), weighted_avg + + def unfreeze_feature_extractor(self): + for param in self.music2vec.parameters(): + param.requires_grad = True + +class Music2VecFeatureExtractor(nn.Module): + def __init__(self, freeze_feature_extractor=True): + super(Music2VecFeatureExtractor, self).__init__() + self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") + self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + + if freeze_feature_extractor: + for param in self.music2vec.parameters(): + param.requires_grad = False + + # Conv1d for learnable weighted average across layers + self.conv1d = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + + def forward(self, input_values): + # input_values: [batch, time] + input_values = input_values.squeeze(1) + with torch.no_grad(): + outputs = self.music2vec(input_values, output_hidden_states=True) + hidden_states = torch.stack(outputs.hidden_states) # [num_layers, batch, time, hidden_dim] + time_reduced = hidden_states.mean(dim=2) # [num_layers, batch, hidden_dim] + time_reduced = time_reduced.permute(1, 0, 2) # [batch, num_layers, hidden_dim] + weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_dim] + return weighted_avg + +''' +music2vec+CCV +# ''' +# import torch +# import torch.nn as nn +# from transformers import Data2VecAudioModel, Wav2Vec2Processor +# import torch.nn.functional as F + + +# ### Music2Vec Feature Extractor (Pretrained Model) +# class Music2VecFeatureExtractor(nn.Module): +# def __init__(self, freeze_feature_extractor=True): +# super(Music2VecFeatureExtractor, self).__init__() + +# self.processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h") +# self.music2vec = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1") + +# if freeze_feature_extractor: +# for param in self.music2vec.parameters(): +# param.requires_grad = False + +# # Conv1d for learnable weighted average across layers +# self.conv1d = 
nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1) + +# def forward(self, input_values): +# with torch.no_grad(): +# outputs = self.music2vec(input_values, output_hidden_states=True) + +# hidden_states = torch.stack(outputs.hidden_states) # [13, batch, time, hidden_size] +# time_reduced = hidden_states.mean(dim=2) # 평균 풀링: [13, batch, hidden_size] +# time_reduced = time_reduced.permute(1, 0, 2) # [batch, 13, hidden_size] +# weighted_avg = self.conv1d(time_reduced).squeeze(1) # [batch, hidden_size] + +# return weighted_avg # Extracted feature representation + + +# def unfreeze_feature_extractor(self): +# for param in self.music2vec.parameters(): +# param.requires_grad = True # Unfreeze for Fine-tuning + +# ### CNN Feature Extractor for CCV +class CNNEncoder(nn.Module): + def __init__(self, embed_dim=512): + super(CNNEncoder, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d((2,1)), # 기존 MaxPool2d(2)를 MaxPool2d((2,1))으로 변경 + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d((1,1)), # 추가된 MaxPool2d(1,1)로 크기 유지 + nn.AdaptiveAvgPool2d((4, 4)) # 최종 크기 조정 + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + # print(f"Input shape before CNNEncoder: {x.shape}") # 디버깅용 출력 + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) + x = self.projection(x) + return x + + +### Cross-Attention Module +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + self.attention_weights = None + + def forward(self, x, cross_input): + attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input) + self.attention_weights = attn_weights + x = self.layer_norm(x + attn_output) + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) + return x + +### Cross-Attention Transformer +class CrossAttentionViT(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CrossAttentionViT, self).__init__() + + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + self.attention_maps = [] + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + self.attention_maps.append(layer.attention_weights) + + x = x.unsqueeze(1).permute(1, 0, 2) + x = self.transformer(x) + x = x.mean(dim=0) + x = self.classifier(x) + return x + +### CCV Model (Final Classifier) +# class CCV(nn.Module): +# def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): +# super(CCV, self).__init__() + +# self.music2vec_extractor = Music2VecClassifier(freeze_feature_extractor=freeze_feature_extractor) + +# # CNN Encoder for Image Representation +# self.encoder = CNNEncoder(embed_dim=embed_dim) + +# # Transformer with Cross-Attention 
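# Illustrative shape sketch (not part of the original file) for the
# CrossAttentionLayer defined above, which uses batch_first=True attention:
#   x           : [B, S_q,  embed_dim]
#   cross_input : [B, S_kv, embed_dim]
#   layer(x, cross_input) -> [B, S_q, embed_dim]
# In the CCV model defined below, both sequences have length 1, since the
# extracted features are passed as features.unsqueeze(1).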
+# self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + +# def forward(self, x, cross_attention_input=None): +# x = self.music2vec_extractor(x) +# # print(f"After Music2VecExtractor: {x.shape}") # prints (batch, 2) + +# # Match the input size expected by CNNEncoder +# x = x.unsqueeze(1).unsqueeze(-1) # reshape to (batch, 1, 2, 1) +# # print(f"Before CNNEncoder: {x.shape}") # check CNN input shape + +# x = self.encoder(x) + +# if cross_attention_input is None: +# cross_attention_input = x + +# x = self.decoder(x, cross_attention_input) + +# return x + +class CCV(nn.Module): + def __init__(self, embed_dim=768, num_heads=8, num_layers=6, num_classes=2, freeze_feature_extractor=True): + super(CCV, self).__init__() + self.feature_extractor = Music2VecFeatureExtractor(freeze_feature_extractor=freeze_feature_extractor) + + # Cross-Attention Transformer + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + # Transformer Encoder + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) + + # Classification Head + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, input_values): + # Extract feature embeddings + features = self.feature_extractor(input_values) # [batch, feature_dim] + # Average over layer dimension if necessary (already [batch, hidden_dim] here) + # Apply Cross-Attention Layers + for layer in self.cross_attention_layers: + features = layer(features.unsqueeze(1), features.unsqueeze(1)).squeeze(1) + # Transformer Encoding + encoded = self.transformer(features.unsqueeze(1)) + encoded = encoded.mean(dim=1) + # Classification Head + logits = self.classifier(encoded) + return logits + + def get_attention_maps(self): + # Implement this if the attention_maps stored by CrossAttentionLayer are needed + return None diff --git a/ISMIR_2025/music2vec/test.py b/ISMIR_2025/music2vec/test.py new file mode 100644 index 0000000000000000000000000000000000000000..010c83dcd2e5bb47d7cd79fcb613ee30e5e3d6c1 --- /dev/null +++ b/ISMIR_2025/music2vec/test.py @@ -0,0 +1,119 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import Music2VecClassifier +import argparse + +''' +python3 test.py --gpu 1 --closed_test --ckpt_path "" +''' +parser = argparse.ArgumentParser(description="AI Music Detection Testing with Music2Vec") +parser.add_argument('--gpu', type=str, default='1', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/music2vec/ckpt/music2vec_pretrain_10.pth', help='Checkpoint directory') +parser.add_argument('--model_name', type=str, default="music2vec", help="Model name") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', 
help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +model = Music2VecClassifier().to(device) + +ckpt_file = os.path.join(args.ckpt_path) +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +print(f"\nLoading model from {ckpt_file}") +model.load_state_dict(torch.load(ckpt_file, map_location=device)) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels, target_duration=10.0) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels, target_duration=10.0) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, val_labels, target_duration=10.0) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) + +def Test(model, test_loader, device): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + loss = F.cross_entropy(output, target) + + test_loss += loss.item() * data.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == target).sum().item() + test_total += target.size(0) + + all_labels.extend(target.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\nTest Results - Loss: {test_loss:.4f} | Test Acc: {test_acc:.3f} | " + f"Test B_ACC: {test_bal_acc:.4f} | Test Prec: {test_precision:.3f} | " + f"Test Rec: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{args.model_name}.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device) diff --git a/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..6d01365bf65abec49d0a9dc587356b45c18bc946 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..784044f606e8cc8e5530ac9e685e3bbde15466ad Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/loss.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3098268ad5d0adf9a55381ccd3f9d04fe59a86e3 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66822dd6b37c07f737cfe0cf7e15e8f67c17ccd0 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/networks.cpython-312.pyc differ diff --git a/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc b/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea8412e40b9fecb971b99a19a4888affcd9a93d0 Binary files /dev/null and b/ISMIR_2025/wav2vec/__pycache__/wav2vec_datalib.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/datalib.py b/ISMIR_2025/wav2vec/datalib.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d13fe30e562a272c953e58c6579c581caf372 --- /dev/null +++ b/ISMIR_2025/wav2vec/datalib.py @@ -0,0 +1,139 @@ +import os +import glob +import random +import torch +import librosa +import numpy as np +import utils +from sklearn.model_selection import train_test_split +from torch.utils.data import Dataset, DataLoader +import scipy.signal as signal +import scipy.signal +from scipy.signal import butter, lfilter +import numpy as np +import scipy.signal as signal +import librosa +import torch +import random +from torch.utils.data import Dataset +import logging +import csv +import logging +import time +import numpy as np +import h5py +import torch +import torchaudio +from imblearn.over_sampling import RandomOverSampler +from networks import Wav2Vec2ForFakeMusic +from transformers import Wav2Vec2Processor +import torchaudio.transforms as T + +class FakeMusicCapsDataset(Dataset): + def __init__(self, file_paths, labels, sr=16000, target_duration=10.0): + self.file_paths = file_paths + self.labels = labels + self.sr = sr + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) + + self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base") + + def highpass_filter(self, y, sr, cutoff=500, order=5): + if isinstance(sr, np.ndarray): + sr = np.mean(sr) + if not isinstance(sr, (int, float)): + raise ValueError(f"[ERROR] sr must be a number, but got {type(sr)}: {sr}") + if sr <= 0: + raise ValueError(f"Invalid sample rate: {sr}. 
It must be greater than 0.") + nyquist = 0.5 * sr + if cutoff <= 0 or cutoff >= nyquist: + print(f"[WARNING] Invalid cutoff frequency {cutoff}, adjusting...") + cutoff = max(10, min(cutoff, nyquist - 1)) + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + + waveform, sr = torchaudio.load(audio_path) + waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sr)(waveform) + + waveform = waveform.squeeze(0) + # augment_audio is not defined in this module, so this branch only runs if it is added later + if label == 0 and hasattr(self, "augment_audio"): + waveform = self.augment_audio(waveform, self.sr) + if label == 1: + waveform = self.highpass_filter(waveform, self.sr) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + start_idx = (current_samples - self.target_samples) // 2 + waveform = waveform[start_idx:start_idx + self.target_samples] + elif current_samples < self.target_samples: + waveform = torch.nn.functional.pad(waveform, (0, self.target_samples - current_samples)) + + waveform = torch.as_tensor(waveform, dtype=torch.float32).unsqueeze(0) # highpass_filter returns a NumPy array + label = torch.tensor(label, dtype=torch.long) + + return waveform, label + +def preprocess_audio(audio_path, target_sr=16000, target_duration=10.0): + waveform, sr = librosa.load(audio_path, sr=target_sr) + + target_samples = int(target_duration * target_sr) + current_samples = len(waveform) + + if current_samples > target_samples: + waveform = waveform[:target_samples] + elif current_samples < target_samples: + waveform = np.pad(waveform, (0, target_samples - current_samples)) + + waveform = torch.tensor(waveform).unsqueeze(0) + return waveform + + +DATASET_PATH = "/data/kym/AI_Music_Detection/audio/FakeMusicCaps" +SUNOCAPS_PATH = "/data/kym/Audio/SunoCaps" # includes the open-set data + +real_files = glob.glob(os.path.join(DATASET_PATH, "real", "**", "*.wav"), recursive=True) +gen_files = glob.glob(os.path.join(DATASET_PATH, "generative", "**", "*.wav"), recursive=True) + +open_real_files = real_files + glob.glob(os.path.join(SUNOCAPS_PATH, "real", "**", "*.wav"), recursive=True) +open_gen_files = gen_files + glob.glob(os.path.join(SUNOCAPS_PATH, "generative", "**", "*.wav"), recursive=True) + +real_labels = [0] * len(real_files) +gen_labels = [1] * len(gen_files) + +open_real_labels = [0] * len(open_real_files) +open_gen_labels = [1] * len(open_gen_files) + +real_train, real_val, real_train_labels, real_val_labels = train_test_split(real_files, real_labels, test_size=0.2, random_state=42) +gen_train, gen_val, gen_train_labels, gen_val_labels = train_test_split(gen_files, gen_labels, test_size=0.2, random_state=42) + +train_files = real_train + gen_train +train_labels = real_train_labels + gen_train_labels +val_files = real_val + gen_val +val_labels = real_val_labels + gen_val_labels + +closed_test_files = real_files + gen_files +closed_test_labels = real_labels + gen_labels + +open_test_files = open_real_files + open_gen_files +open_test_labels = open_real_labels + open_gen_labels + +ros = RandomOverSampler(sampling_strategy='auto', random_state=42) +train_files_resampled, train_labels_resampled = ros.fit_resample(np.array(train_files).reshape(-1, 1), train_labels) + +train_files = train_files_resampled.reshape(-1).tolist() +train_labels = train_labels_resampled + +print(f"Train Original FAKE: {len(gen_train)}") +print(f"Train set (Oversampled) - REAL: {sum(1 for label in 
train_labels if label == 0)}, " + f"FAKE: {sum(1 for label in train_labels if label == 1)}, Total: {len(train_files)}") +print(f"Validation set - REAL: {len(real_val)}, FAKE: {len(gen_val)}, Total: {len(val_files)}") diff --git a/ISMIR_2025/wav2vec/inference.py b/ISMIR_2025/wav2vec/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d0951f644b839155028b6b5fcd2811fd0126cbbf --- /dev/null +++ b/ISMIR_2025/wav2vec/inference.py @@ -0,0 +1,71 @@ +import os +import torch +import torch.nn.functional as F +import torchaudio +import argparse +from AI_Music_Detection.Code.model.wav2vec.wav2vec_datalib import preprocess_audio +from networks import Wav2Vec2ForFakeMusic + +''' +command: python inference.py --gpu 0 --model_type pretrain --inference .wav +''' +parser = argparse.ArgumentParser(description="Wav2Vec2 AI Music Detection Inference") +parser.add_argument('--gpu', type=str, default='0', help='GPU ID') +parser.add_argument('--model_name', type=str, choices=['Wav2Vec2ForFakeMusic'], default='Wav2Vec2ForFakeMusic', help='Model name') +parser.add_argument('--ckpt_path', type=str, default='/data/kym/AI_Music_Detection/Code/model/wav2vec/ckpt/', help='Checkpoint directory') +parser.add_argument('--model_type', type=str, choices=['pretrain', 'finetune'], required=True, help='Choose between pretrained or fine-tuned model') +parser.add_argument('--inference', type=str, help='Path to a .wav file for inference') +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +if args.model_type == 'pretrain': + model_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_10.pth") +elif args.model_type == 'finetune': + model_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_5.pth") +else: + raise ValueError("Invalid model type. 
Choose between 'pretrain' or 'finetune'.") + +if not os.path.exists(model_file): + raise FileNotFoundError(f"Model checkpoint not found: {model_file}") + +if args.model_name == 'Wav2Vec2ForFakeMusic': + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=(args.model_type == 'finetune')) +else: + raise ValueError(f"Invalid model name: {args.model_name}") + +def predict(audio_path): + print(f"\n🔍 Loading model from {model_file}") + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"[ERROR] Audio file not found: {audio_path}") + + model.load_state_dict(torch.load(model_file, map_location=device)) # load the selected checkpoint weights + model.to(device) + model.eval() + + input_tensor = preprocess_audio(audio_path).to(device) + print(f"Input shape after preprocessing: {input_tensor.shape}") + + with torch.no_grad(): + output = model(input_tensor) + print(f"Raw model output (logits): {output}") + + probabilities = F.softmax(output, dim=1) + ai_music_prob = probabilities[0, 1].item() + + print(f"Softmax Probabilities: {probabilities}") + print(f"AI Music Probability: {ai_music_prob:.4f}") + + if ai_music_prob > 0.5: + print(f" FAKE MUSIC DETECTED ({ai_music_prob:.2%})") + else: + print(f" REAL MUSIC DETECTED ({100 - ai_music_prob * 100:.2f}%)") + + + +if __name__ == "__main__": + if args.inference: + if not os.path.exists(args.inference): + print(f"[ERROR] No File Found: {args.inference}") + else: + predict(args.inference) + diff --git a/ISMIR_2025/wav2vec/main.py b/ISMIR_2025/wav2vec/main.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0fca7906256ada966b563a613bf6842c475c18 --- /dev/null +++ b/ISMIR_2025/wav2vec/main.py @@ -0,0 +1,162 @@ +import os +import random +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, classification_report +import wandb +import argparse +from datalib import FakeMusicCapsDataset, train_files, train_labels, val_files, val_labels +from networks import Wav2Vec2ForFakeMusic + +''' +python main.py --gpu 0 --pretrain_epochs 20 --finetune_epochs 10 +''' +parser = argparse.ArgumentParser(description='AI Music Detection Training with Wav2Vec 2.0') +parser.add_argument('--gpu', type=str, default='2', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate') +parser.add_argument('--pretrain_epochs', type=int, default=20, help='Pretraining epochs (REAL data only)') +parser.add_argument('--finetune_epochs', type=int, default=10, help='Fine-tuning epochs (REAL + FAKE data)') +parser.add_argument('--checkpoint_dir', type=str, default='', help='Checkpoint directory') +parser.add_argument('--weight_decay', type=float, default=0.05, help="Weight decay for optimizer") + +args = parser.parse_args() + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(42) +random.seed(42) +np.random.seed(42) + +wandb.init(project="", name=f"pretrain_{args.pretrain_epochs}_finetune_{args.finetune_epochs}", config=args) + +print("Preparing datasets...") +train_dataset = FakeMusicCapsDataset(train_files, train_labels) +val_dataset = FakeMusicCapsDataset(val_files, val_labels) + +train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4) +val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4) + 
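+# Illustrative sanity check (not part of the original script): the dataset is assumed to return +# (1, 160000) mono 16 kHz clips, so one batch from the loader should have shape (batch, 1, 160000); +# Wav2Vec2ForFakeMusic squeezes the channel dimension internally before the wav2vec 2.0 encoder. +# waveforms, labels = next(iter(train_loader)) +# print(waveforms.shape, labels.shape) # expected: torch.Size([batch, 1, 160000]), torch.Size([batch]) + 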
+pretrain_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_pretrain_{args.pretrain_epochs}.pth") +finetune_ckpt = os.path.join(args.checkpoint_dir, f"wav2vec2_finetune_{args.finetune_epochs}.pth") + +print("Initializing model...") +model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) + +criterion = nn.CrossEntropyLoss() +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) +scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) + +def train(model, dataloader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain"): + model.train() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + attention_maps = [] + + for inputs, labels in tqdm(dataloader, desc=f"{phase} Training Epoch {epoch+1}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.float() + + outputs = model(inputs) + loss = criterion(outputs, labels) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + if hasattr(model, "get_attention_maps"): + attention_maps.append(model.get_attention_maps()) + + scheduler.step() + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + precision = precision_score(all_labels, all_preds, average="binary") + recall = recall_score(all_labels, all_preds, average="binary") + balanced_acc = balanced_accuracy_score(all_labels, all_preds) + + wandb.log({ + f"{phase} Train Loss": total_loss / len(dataloader), + f"{phase} Train Accuracy": accuracy, + f"{phase} Train F1 Score": f1, + f"{phase} Train Precision": precision, + f"{phase} Train Recall": recall, + f"{phase} Train Balanced Accuracy": balanced_acc, + }) + + print(f"{phase} Train Epoch {epoch+1}: Train Loss: {total_loss / len(dataloader):.4f}, " + f"Train Acc: {accuracy:.4f}, Train F1: {f1:.4f}, Train Prec: {precision:.4f}, Train Rec: {recall:.4f}, B_ACC: {balanced_acc:.4f}") + +def validate(model, dataloader, criterion, device, phase="Validation"): + model.eval() + total_loss, total_correct, total_samples = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for inputs, labels in tqdm(dataloader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) + + outputs = model(inputs) + loss = criterion(outputs, labels) + + + total_loss += loss.item() + preds = outputs.argmax(dim=1) + total_correct += (preds == labels).sum().item() + total_samples += labels.size(0) + + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + accuracy = total_correct / total_samples + f1 = f1_score(all_labels, all_preds, average="weighted") + val_bal_acc = balanced_accuracy_score(all_labels, all_preds) + val_precision = precision_score(all_labels, all_preds, average="binary") + val_recall = recall_score(all_labels, all_preds, average="binary") + + wandb.log({ + f"{phase} Val Loss": total_loss / len(dataloader), + f"{phase} Val Accuracy": accuracy, + f"{phase} Val F1 Score": f1, + f"{phase} Val Precision": val_precision, + f"{phase} Val Recall": val_recall, + f"{phase} Val Balanced Accuracy": val_bal_acc, + }) + print(f"{phase} Val Loss: {total_loss / len(dataloader):.4f}, " + f"Val Acc: {accuracy:.4f}, Val F1: {f1:.4f}, Val 
Prec: {val_precision:.4f}, Val Rec: {val_recall:.4f}, Val B_ACC: {val_bal_acc:.4f}") + return total_loss / len(dataloader), accuracy, f1 + +print("\nStep 1: Self-Supervised Pretraining on REAL Data") +for epoch in range(args.pretrain_epochs): + train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Pretrain") + +torch.save(model.state_dict(), pretrain_ckpt) +print(f"\nPretraining completed! Model saved at: {pretrain_ckpt}") + +model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device) +model.load_state_dict(torch.load(pretrain_ckpt)) +print(f"\n🔍 Loaded Pretrained Model from {pretrain_ckpt}") + +optimizer = optim.Adam(model.parameters(), lr=args.learning_rate / 10, weight_decay=args.weight_decay) + +print("\nStep 2: Fine-Tuning on REAL + FAKE Data") +for epoch in range(args.finetune_epochs): + train(model, train_loader, optimizer, criterion, scheduler, device, epoch, phase="Fine-Tune") + validate(model, val_loader, criterion, device, phase="Fine-Tune Validation") + +torch.save(model.state_dict(), finetune_ckpt) +print(f"\nFine-Tuning completed! Model saved at: {finetune_ckpt}") diff --git a/ISMIR_2025/wav2vec/networks.py b/ISMIR_2025/wav2vec/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..63634b0f5323a5ae7cd0d2618c448c8ad261a7e7 --- /dev/null +++ b/ISMIR_2025/wav2vec/networks.py @@ -0,0 +1,161 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import seaborn as sns + +''' +With freeze_feature_extractor=True the feature extractor is frozen (pretraining stage). +Calling unfreeze_feature_extractor() enables fine-tuning. +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +import matplotlib.pyplot as plt +import seaborn as sns +from transformers import Wav2Vec2Model + +class cnn(nn.Module): + def __init__(self, embed_dim=512): + super(cnn, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4, 4)) + ) + self.projection = nn.Linear(32 * 4 * 4, embed_dim) + + def forward(self, x): + x = self.conv_block(x) + B, C, H, W = x.shape + x = x.view(B, -1) + x = self.projection(x) + return x + +class CrossAttentionLayer(nn.Module): + def __init__(self, embed_dim, num_heads): + super(CrossAttentionLayer, self).__init__() + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + self.layer_norm = nn.LayerNorm(embed_dim) + self.feed_forward = nn.Sequential( + nn.Linear(embed_dim, embed_dim * 4), + nn.ReLU(), + nn.Linear(embed_dim * 4, embed_dim) + ) + self.attention_weights = None + + def forward(self, x, cross_input): + # Cross-attention between x and cross_input + attn_output, attn_weights = self.multihead_attn(query=x, key=cross_input, value=cross_input) + self.attention_weights = attn_weights + x = self.layer_norm(x + attn_output) + feed_forward_output = self.feed_forward(x) + x = self.layer_norm(x + feed_forward_output) + return x + +class CrossAttentionViT(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CrossAttentionViT, self).__init__() + + self.cross_attention_layers = nn.ModuleList([ + CrossAttentionLayer(embed_dim, num_heads) for _ in range(num_layers) + ]) + + encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) + self.transformer = nn.TransformerEncoder(encoder_layer, 
num_layers=num_layers) + + self.classifier = nn.Sequential( + nn.LayerNorm(embed_dim), + nn.Linear(embed_dim, num_classes) + ) + + def forward(self, x, cross_attention_input): + self.attention_maps = [] + for layer in self.cross_attention_layers: + x = layer(x, cross_attention_input) + self.attention_maps.append(layer.attention_weights) + + x = x.unsqueeze(1).permute(1, 0, 2) + x = self.transformer(x) + x = x.mean(dim=0) + x = self.classifier(x) + return x + +class CCV(nn.Module): + def __init__(self, embed_dim=512, num_heads=8, num_layers=6, num_classes=2): + super(CCV, self).__init__() + self.encoder = cnn(embed_dim=embed_dim) + self.decoder = CrossAttentionViT(embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers, num_classes=num_classes) + + def forward(self, x, cross_attention_input=None): + x = self.encoder(x) + + if cross_attention_input is None: + cross_attention_input = x + + x = self.decoder(x, cross_attention_input) + + # Store the decoder's attention maps + self.attention_maps = self.decoder.attention_maps + + return x + + def get_attention_maps(self): + return self.attention_maps + +import torch +import torch.nn as nn +from transformers import Wav2Vec2Model + +class Wav2Vec2ForFakeMusic(nn.Module): + def __init__(self, num_classes=2, freeze_feature_extractor=True): + super(Wav2Vec2ForFakeMusic, self).__init__() + + self.wav2vec = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base") + + if freeze_feature_extractor: + for param in self.wav2vec.parameters(): + param.requires_grad = False + + self.classifier = nn.Sequential( + nn.Linear(self.wav2vec.config.hidden_size, 256), # 768 → 256 + nn.ReLU(), + nn.Dropout(0.3), + nn.Linear(256, num_classes) # 256 → 2 (Binary Classification) + ) + + def forward(self, x): + x = x.squeeze(1) + output = self.wav2vec(x) + features = output["last_hidden_state"] # (batch_size, seq_len, feature_dim) + pooled_features = features.mean(dim=1) # mean pooling over time (batch_size, feature_dim) + logits = self.classifier(pooled_features) # (batch_size, num_classes) + + return logits # callers apply cross_entropy/softmax directly to this output + + +def visualize_attention_map(attn_map, mel_spec, layer_idx): + attn_map = attn_map.mean(dim=1).squeeze().cpu().numpy() # average over attention heads + mel_spec = mel_spec.squeeze().cpu().numpy() + + fig, axs = plt.subplots(2, 1, figsize=(10, 8)) + + # Log-Mel spectrogram + sns.heatmap(mel_spec, cmap='inferno', ax=axs[0]) + axs[0].set_title("Log-Mel Spectrogram") + axs[0].set_xlabel("Time Frames") + axs[0].set_ylabel("Mel Frequency Bins") + + # Attention map + sns.heatmap(attn_map, cmap='viridis', ax=axs[1]) + axs[1].set_title(f"Attention Map (Layer {layer_idx})") + axs[1].set_xlabel("Time Frames") + axs[1].set_ylabel("Query Positions") + + plt.tight_layout() + plt.savefig("/data/kym/AI_Music_Detection/Code/model/attention_map/crossattn.png") + plt.show() diff --git a/ISMIR_2025/wav2vec/test.py b/ISMIR_2025/wav2vec/test.py new file mode 100644 index 0000000000000000000000000000000000000000..d74d07983c843a247b1a15a4112d67f82ef34521 --- /dev/null +++ b/ISMIR_2025/wav2vec/test.py @@ -0,0 +1,148 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +from sklearn.metrics import precision_score, recall_score, f1_score, balanced_accuracy_score, confusion_matrix +from datalib import ( + FakeMusicCapsDataset, + closed_test_files, closed_test_labels, + open_test_files, open_test_labels, + val_files, val_labels +) +from networks import Wav2Vec2ForFakeMusic +import tqdm +from tqdm 
import tqdm +import argparse +''' +python3 test.py --finetune_test --closed_test | --open_test +''' +parser = argparse.ArgumentParser(description="AI Music Detection Testing with Wav2Vec 2.0") +parser.add_argument('--gpu', type=str, default='0', help='GPU ID') +parser.add_argument('--batch_size', type=int, default=32, help='Batch size') +parser.add_argument('--ckpt_path', type=str, default='', help='Checkpoint directory') +parser.add_argument('--pretrain_test', action="store_true", help="Test Pretrained Wav2Vec2 Model") +parser.add_argument('--finetune_test', action="store_true", help="Test Fine-Tuned Wav2Vec2 Model") +parser.add_argument('--closed_test', action="store_true", help="Use Closed Test (FakeMusicCaps full dataset)") +parser.add_argument('--open_test', action="store_true", help="Use Open Set Test (SUNOCAPS_PATH included)") +parser.add_argument('--output_path', type=str, default='', help='Path to save test results') + +args = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def plot_confusion_matrix(y_true, y_pred, classes, output_path): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + plt.savefig(output_path) + plt.close(fig) + +if args.pretrain_test: + ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_pretrain_20.pth") + print("\n🔍 Loading Pretrained Model:", ckpt_file) + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=True).to(device) + +elif args.finetune_test: + ckpt_file = os.path.join(args.ckpt_path, "wav2vec2_finetune_10.pth") + print("\n🔍 Loading Fine-Tuned Model:", ckpt_file) + model = Wav2Vec2ForFakeMusic(num_classes=2, freeze_feature_extractor=False).to(device) + +else: + raise ValueError("You must specify --pretrain_test or --finetune_test") + +if not os.path.exists(ckpt_file): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}") + +# model.load_state_dict(torch.load(ckpt_file, map_location=device)) +# model.eval() + +ckpt = torch.load(ckpt_file, map_location=device) + +keys_to_remove = [key for key in ckpt.keys() if "masked_spec_embed" in key] +for key in keys_to_remove: + print(f"Removing unexpected key: {key}") + del ckpt[key] + +try: + model.load_state_dict(ckpt, strict=False) +except RuntimeError as e: + print("Model loading error:", e) + print("Trying to load entire model...") + model = torch.load(ckpt_file, map_location=device) +model.to(device) +model.eval() + +torch.cuda.empty_cache() + +if args.closed_test: + print("\nRunning Closed Test (FakeMusicCaps Full Dataset)...") + test_dataset = FakeMusicCapsDataset(closed_test_files, closed_test_labels) +elif args.open_test: + print("\nRunning Open Set Test (FakeMusicCaps + SunoCaps)...") + test_dataset = FakeMusicCapsDataset(open_test_files, open_test_labels) +else: + print("\nRunning Validation Test (FakeMusicCaps 20% Validation Set)...") + test_dataset = FakeMusicCapsDataset(val_files, 
val_labels) + +test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=16) + +def Test(model, test_loader, device, phase="Test"): + model.eval() + test_loss, test_correct, test_total = 0, 0, 0 + all_preds, all_labels = [], [] + + with torch.no_grad(): + for inputs, labels in tqdm(test_loader, desc=f"{phase}"): + inputs, labels = inputs.to(device), labels.to(device) + inputs = inputs.squeeze(1) # Ensure correct input shape + + output = model(inputs) + loss = F.cross_entropy(output, labels) + + test_loss += loss.item() * inputs.size(0) + preds = output.argmax(dim=1) + test_correct += (preds == labels).sum().item() + test_total += labels.size(0) + + all_labels.extend(labels.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + test_loss /= test_total + test_acc = test_correct / test_total + test_bal_acc = balanced_accuracy_score(all_labels, all_preds) + test_precision = precision_score(all_labels, all_preds, average="binary") + test_recall = recall_score(all_labels, all_preds, average="binary") + test_f1 = f1_score(all_labels, all_preds, average="binary") + + print(f"\n{phase} Test Results - Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.3f} | " + f"Test Balanced Acc: {test_bal_acc:.4f} | Test Precision: {test_precision:.3f} | " + f"Test Recall: {test_recall:.3f} | Test F1: {test_f1:.3f}") + + os.makedirs(args.output_path, exist_ok=True) + conf_matrix_path = os.path.join(args.output_path, f"confusion_matrix_{phase}_opentest.png") + plot_confusion_matrix(all_labels, all_preds, classes=["real", "generative"], output_path=conf_matrix_path) + +print("\nEvaluating Model on Test Set...") +Test(model, test_loader, device, phase="Pretrained Model" if args.pretrain_test else "Fine-Tuned Model") diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5105baa6b5cc20b8d2fe0a9b029b40c9e6f7c6af Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/config.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb9379cca8c0e244eb2a7b169fa2ce307f2bab05 Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/idr_torch.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc b/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ed2e09ebbddb16a923974a1ef3c30f2259ac34 Binary files /dev/null and b/ISMIR_2025/wav2vec/utils/__pycache__/utilities.cpython-311.pyc differ diff --git a/ISMIR_2025/wav2vec/utils/config.py b/ISMIR_2025/wav2vec/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..69f72ecd472eed266bb9a0d811d7eeb07a3c06db --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/config.py @@ -0,0 +1,565 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import csv + +import numpy as np + +sample_rate = 32000 +clip_samples = sample_rate * 10 # Audio clips are 10-second + +# Load label +with open( + "/gpfswork/rech/djl/uzj43um/audio_retrieval/audioset_tagging_cnn/metadata/class_labels_indices.csv", + "r", +) as f: + reader = csv.reader(f, delimiter=",") + lines = list(reader) + +labels = [] +ids = [] # Each label has a unique id such as "/m/068hy" +for i1 in range(1, 
len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + +classes_num = len(labels) + +lb_to_ix = {label: i for i, label in enumerate(labels)} +ix_to_lb = {i: label for i, label in enumerate(labels)} + +id_to_ix = {id: i for i, id in enumerate(ids)} +ix_to_id = {i: id for i, id in enumerate(ids)} + +full_samples_per_class = np.array( + [ + 937432, + 16344, + 7822, + 10271, + 2043, + 14420, + 733, + 1511, + 1258, + 424, + 1751, + 704, + 369, + 590, + 1063, + 1375, + 5026, + 743, + 853, + 1648, + 714, + 1497, + 1251, + 2139, + 1093, + 133, + 224, + 39469, + 6423, + 407, + 1559, + 4546, + 6826, + 7464, + 2468, + 549, + 4063, + 334, + 587, + 238, + 1766, + 691, + 114, + 2153, + 236, + 209, + 421, + 740, + 269, + 959, + 137, + 4192, + 485, + 1515, + 655, + 274, + 69, + 157, + 1128, + 807, + 1022, + 346, + 98, + 680, + 890, + 352, + 4169, + 2061, + 1753, + 9883, + 1339, + 708, + 37857, + 18504, + 12864, + 2475, + 2182, + 757, + 3624, + 677, + 1683, + 3583, + 444, + 1780, + 2364, + 409, + 4060, + 3097, + 3143, + 502, + 723, + 600, + 230, + 852, + 1498, + 1865, + 1879, + 2429, + 5498, + 5430, + 2139, + 1761, + 1051, + 831, + 2401, + 2258, + 1672, + 1711, + 987, + 646, + 794, + 25061, + 5792, + 4256, + 96, + 8126, + 2740, + 752, + 513, + 554, + 106, + 254, + 1592, + 556, + 331, + 615, + 2841, + 737, + 265, + 1349, + 358, + 1731, + 1115, + 295, + 1070, + 972, + 174, + 937780, + 112337, + 42509, + 49200, + 11415, + 6092, + 13851, + 2665, + 1678, + 13344, + 2329, + 1415, + 2244, + 1099, + 5024, + 9872, + 10948, + 4409, + 2732, + 1211, + 1289, + 4807, + 5136, + 1867, + 16134, + 14519, + 3086, + 19261, + 6499, + 4273, + 2790, + 8820, + 1228, + 1575, + 4420, + 3685, + 2019, + 664, + 324, + 513, + 411, + 436, + 2997, + 5162, + 3806, + 1389, + 899, + 8088, + 7004, + 1105, + 3633, + 2621, + 9753, + 1082, + 26854, + 3415, + 4991, + 2129, + 5546, + 4489, + 2850, + 1977, + 1908, + 1719, + 1106, + 1049, + 152, + 136, + 802, + 488, + 592, + 2081, + 2712, + 1665, + 1128, + 250, + 544, + 789, + 2715, + 8063, + 7056, + 2267, + 8034, + 6092, + 3815, + 1833, + 3277, + 8813, + 2111, + 4662, + 2678, + 2954, + 5227, + 1472, + 2591, + 3714, + 1974, + 1795, + 4680, + 3751, + 6585, + 2109, + 36617, + 6083, + 16264, + 17351, + 3449, + 5034, + 3931, + 2599, + 4134, + 3892, + 2334, + 2211, + 4516, + 2766, + 2862, + 3422, + 1788, + 2544, + 2403, + 2892, + 4042, + 3460, + 1516, + 1972, + 1563, + 1579, + 2776, + 1647, + 4535, + 3921, + 1261, + 6074, + 2922, + 3068, + 1948, + 4407, + 712, + 1294, + 1019, + 1572, + 3764, + 5218, + 975, + 1539, + 6376, + 1606, + 6091, + 1138, + 1169, + 7925, + 3136, + 1108, + 2677, + 2680, + 1383, + 3144, + 2653, + 1986, + 1800, + 1308, + 1344, + 122231, + 12977, + 2552, + 2678, + 7824, + 768, + 8587, + 39503, + 3474, + 661, + 430, + 193, + 1405, + 1442, + 3588, + 6280, + 10515, + 785, + 710, + 305, + 206, + 4990, + 5329, + 3398, + 1771, + 3022, + 6907, + 1523, + 8588, + 12203, + 666, + 2113, + 7916, + 434, + 1636, + 5185, + 1062, + 664, + 952, + 3490, + 2811, + 2749, + 2848, + 15555, + 363, + 117, + 1494, + 1647, + 5886, + 4021, + 633, + 1013, + 5951, + 11343, + 2324, + 243, + 372, + 943, + 734, + 242, + 3161, + 122, + 127, + 201, + 1654, + 768, + 134, + 1467, + 642, + 1148, + 2156, + 1368, + 1176, + 302, + 1909, + 61, + 223, + 1812, + 287, + 422, + 311, + 228, + 748, + 230, + 1876, + 539, + 1814, + 737, + 689, + 1140, + 591, + 943, + 353, + 289, + 198, + 490, + 7938, + 1841, + 850, + 457, + 814, + 146, + 551, + 728, + 1627, + 620, + 648, + 1621, + 2731, + 
535, + 88, + 1736, + 736, + 328, + 293, + 3170, + 344, + 384, + 7640, + 433, + 215, + 715, + 626, + 128, + 3059, + 1833, + 2069, + 3732, + 1640, + 1508, + 836, + 567, + 2837, + 1151, + 2068, + 695, + 1494, + 3173, + 364, + 88, + 188, + 740, + 677, + 273, + 1533, + 821, + 1091, + 293, + 647, + 318, + 1202, + 328, + 532, + 2847, + 526, + 721, + 370, + 258, + 956, + 1269, + 1641, + 339, + 1322, + 4485, + 286, + 1874, + 277, + 757, + 1393, + 1330, + 380, + 146, + 377, + 394, + 318, + 339, + 1477, + 1886, + 101, + 1435, + 284, + 1425, + 686, + 621, + 221, + 117, + 87, + 1340, + 201, + 1243, + 1222, + 651, + 1899, + 421, + 712, + 1016, + 1279, + 124, + 351, + 258, + 7043, + 368, + 666, + 162, + 7664, + 137, + 70159, + 26179, + 6321, + 32236, + 33320, + 771, + 1169, + 269, + 1103, + 444, + 364, + 2710, + 121, + 751, + 1609, + 855, + 1141, + 2287, + 1940, + 3943, + 289, + ] +) \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py b/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e57d6d77e51949970ea76d8400d78ed6540cc155 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/confusion_matrix_plot.py @@ -0,0 +1,29 @@ +from sklearn.metrics import confusion_matrix +import matplotlib.pyplot as plt +import numpy as np + +def plot_confusion_matrix(y_true, y_pred, classes, writer, epoch): + cm = confusion_matrix(y_true, y_pred) + fig, ax = plt.subplots(figsize=(6, 6)) + im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) + ax.figure.colorbar(im, ax=ax) + + num_classes = cm.shape[0] + tick_labels = classes[:num_classes] + + ax.set(xticks=np.arange(num_classes), + yticks=np.arange(num_classes), + xticklabels=tick_labels, + yticklabels=tick_labels, + ylabel='True label', + xlabel='Predicted label') + + thresh = cm.max() / 2. 
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], 'd'), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + + fig.tight_layout() + writer.add_figure("Confusion Matrix", fig, epoch) \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/freqeuncy.py b/ISMIR_2025/wav2vec/utils/freqeuncy.py new file mode 100644 index 0000000000000000000000000000000000000000..b21c5222467ec4906c63e5b9d02052a69aeb67e2 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/freqeuncy.py @@ -0,0 +1,24 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt + +# 🔹 Load the audio files +file_real = "/path/to/real_audio.wav" # path to a real audio clip +file_fake = "/path/to/generative_audio.wav" # path to an AI-generated audio clip + +def plot_spectrogram(audio_file, title): + y, sr = librosa.load(audio_file, sr=16000) # sampling rate 16 kHz + D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max) # STFT magnitude in dB + + plt.figure(figsize=(10, 4)) + librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='hz', cmap='magma') + plt.colorbar(format='%+2.0f dB') + plt.title(title) + plt.ylim(4000, 8000) # show only the high-frequency band above 4 kHz (Nyquist is 8 kHz at sr=16000) + plt.show() + +# 🔹 Compare real vs. generative spectrograms +plot_spectrogram(file_real, "Real Audio Spectrogram (4kHz+)") +plot_spectrogram(file_fake, "Generative Audio Spectrogram (4kHz+)") + diff --git a/ISMIR_2025/wav2vec/utils/hf_vis.py b/ISMIR_2025/wav2vec/utils/hf_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..c99b61bfb27f99880b0c44313daf476e6c0c278f --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/hf_vis.py @@ -0,0 +1,89 @@ +import librosa +import librosa.display +import numpy as np +import matplotlib.pyplot as plt +import scipy.signal as signal +import torch +import torch.nn as nn +import soundfile as sf + +from networks import audiocnn, AudioCNNWithViTDecoder, AudioCNNWithViTDecoderAndCrossAttention + + +def highpass_filter(y, sr, cutoff=500, order=5): + """High-pass filter to remove low frequencies below `cutoff` Hz.""" + nyquist = 0.5 * sr + normal_cutoff = cutoff / nyquist + b, a = signal.butter(order, normal_cutoff, btype='high', analog=False) + y_filtered = signal.lfilter(b, a, y) + return y_filtered + +def plot_combined_visualization(y_original, y_filtered, sr, save_path="combined_visualization.png"): + """Plot waveform comparison and spectrograms in a single figure.""" + fig, axes = plt.subplots(3, 1, figsize=(12, 12)) + + # 1️⃣ Waveform Comparison + time = np.linspace(0, len(y_original) / sr, len(y_original)) + axes[0].plot(time, y_original, label='Original', alpha=0.7) + axes[0].plot(time, y_filtered, label='High-pass Filtered', alpha=0.7, linestyle='dashed') + axes[0].set_xlabel("Time (s)") + axes[0].set_ylabel("Amplitude") + axes[0].set_title("Waveform Comparison (Original vs High-pass Filtered)") + axes[0].legend() + + # 2️⃣ Spectrogram - Original + S_orig = librosa.amplitude_to_db(np.abs(librosa.stft(y_original)), ref=np.max) + img = librosa.display.specshow(S_orig, sr=sr, x_axis='time', y_axis='log', ax=axes[1]) + axes[1].set_title("Original Spectrogram") + fig.colorbar(img, ax=axes[1], format="%+2.0f dB") + + # 3️⃣ Spectrogram - High-pass Filtered + S_filt = librosa.amplitude_to_db(np.abs(librosa.stft(y_filtered)), ref=np.max) + img = librosa.display.specshow(S_filt, sr=sr, x_axis='time', y_axis='log', ax=axes[2]) + axes[2].set_title("High-pass Filtered Spectrogram") + fig.colorbar(img, ax=axes[2], format="%+2.0f dB") + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.show() + + +def 
load_model(checkpoint_path, model_class, device): + """Load a trained model from checkpoint.""" + model = model_class() + model.load_state_dict(torch.load(checkpoint_path, map_location=device)) + model.to(device) + model.eval() + return model + +def predict_audio(model, audio_tensor, device): + """Make predictions using a trained model.""" + with torch.no_grad(): + audio_tensor = audio_tensor.unsqueeze(0).to(device) # Add batch dimension + output = model(audio_tensor) + prediction = torch.argmax(output, dim=1).cpu().numpy()[0] + return prediction + +# Load audio +audio_path = "/data/kym/AI Music Detection/audio/FakeMusicCaps/real/musiccaps/_RrA-0lfIiU.wav" # Replace with actual file path +y, sr = librosa.load(audio_path, sr=None) +y_filtered = highpass_filter(y, sr, cutoff=500) + +# Convert audio to tensor +audio_tensor = torch.tensor(librosa.feature.melspectrogram(y=y, sr=sr), dtype=torch.float).unsqueeze(0) +audio_tensor_filtered = torch.tensor(librosa.feature.melspectrogram(y=y_filtered, sr=sr), dtype=torch.float).unsqueeze(0) + +# Load models +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +original_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/pretraining/best_model_audiocnn.pth", audiocnn, device) +highpass_model = load_model("/data/kym/AI Music Detection/AudioCNN/ckpt/FakeMusicCaps/500hz_Add_crossattn_decoder/best_model_AudioCNNWithViTDecoderAndCrossAttention.pth", AudioCNNWithViTDecoderAndCrossAttention, device) + +# Predict +original_pred = predict_audio(original_model, audio_tensor, device) +highpass_pred = predict_audio(highpass_model, audio_tensor_filtered, device) + +print(f"Original Model Prediction: {original_pred}") +print(f"High-pass Filter Model Prediction: {highpass_pred}") + +# Generate combined visualization (all plots in one image) +plot_combined_visualization(y, y_filtered, sr, save_path="/data/kym/AI Music Detection/AudioCNN/hf_vis/rawvs500.png") diff --git a/ISMIR_2025/wav2vec/utils/idr_torch.py b/ISMIR_2025/wav2vec/utils/idr_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..b2e76040394ce27390c27bd8ef022e126d8e55dc --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/idr_torch.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import hostlist + +# get SLURM variables +# rank = int(os.environ["SLURM_PROCID"]) +local_rank = int(os.environ["SLURM_LOCALID"]) +size = int(os.environ["SLURM_NTASKS"]) +cpus_per_task = int(os.environ["SLURM_CPUS_PER_TASK"]) + +# get node list from slurm +hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) + +# get IDs of reserved GPU +gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") + +# define MASTER_ADD & MASTER_PORT +os.environ["MASTER_ADDR"] = hostnames[0] +os.environ["MASTER_PORT"] = str( + 12345 + int(min(gpu_ids)) +) # to avoid port conflict on the same node \ No newline at end of file diff --git a/ISMIR_2025/wav2vec/utils/mfcc.py b/ISMIR_2025/wav2vec/utils/mfcc.py new file mode 100644 index 0000000000000000000000000000000000000000..5d63db14375fedcc1cc60f2ef3cecf5c70e9a8fb --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/mfcc.py @@ -0,0 +1,266 @@ +import os +import glob +import librosa +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import Dataset, DataLoader, random_split +import torch.nn.functional as F +from sklearn.metrics import precision_score, recall_score, f1_score +from tqdm import tqdm +import argparse +import wandb + +class RealFakeDataset(Dataset): + 
""" + audio/FakeMusicCaps/ + ├─ real/ + │ └─ MusicCaps/*.wav (label=0) + └─ generative/ + └─ .../*.wav (label=1) + """ + def __init__(self, root_dir, sr=16000, n_mels=64, target_duration=10.0): + + self.sr = sr + self.n_mels = n_mels + self.target_duration = target_duration + self.target_samples = int(target_duration * sr) # 10초 = 160,000 샘플 + + self.file_paths = [] + self.labels = [] + + # Real 데이터 (label=0) + real_dir = os.path.join(root_dir, "real") + real_wav_files = glob.glob(os.path.join(real_dir, "**", "*.wav"), recursive=True) + for f in real_wav_files: + self.file_paths.append(f) + self.labels.append(0) + + # Generative 데이터 (label=1) + gen_dir = os.path.join(root_dir, "generative") + gen_wav_files = glob.glob(os.path.join(gen_dir, "**", "*.wav"), recursive=True) + for f in gen_wav_files: + self.file_paths.append(f) + self.labels.append(1) + + def __len__(self): + return len(self.file_paths) + + def __getitem__(self, idx): + audio_path = self.file_paths[idx] + label = self.labels[idx] + # print(f"[DEBUG] Path: {audio_path}, Label: {label}") # 추가 + + waveform, sr = librosa.load(audio_path, sr=self.sr, mono=True) + + current_samples = waveform.shape[0] + if current_samples > self.target_samples: + waveform = waveform[:self.target_samples] + elif current_samples < self.target_samples: + stretch_factor = self.target_samples / current_samples + waveform = librosa.effects.time_stretch(waveform, rate=stretch_factor) + waveform = waveform[:self.target_samples] + + mfcc = librosa.feature.mfcc( + y=waveform, sr=self.sr, n_mfcc=self.n_mels, n_fft=1024, hop_length=256 + ) + mfcc = librosa.util.normalize(mfcc) + + mfcc = np.expand_dims(mfcc, axis=0) + mfcc_tensor = torch.tensor(mfcc, dtype=torch.float) + label_tensor = torch.tensor(label, dtype=torch.long) + + return mfcc_tensor, label_tensor + + + +class AudioCNN(nn.Module): + def __init__(self, num_classes=2): + super(AudioCNN, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 32, kernel_size=3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.AdaptiveAvgPool2d((4,4)) # 최종 -> (B,32,4,4) + ) + self.fc_block = nn.Sequential( + nn.Linear(32*4*4, 128), + nn.ReLU(), + nn.Linear(128, num_classes) + ) + + + def forward(self, x): + x = self.conv_block(x) + # x.shape: (B,32,new_freq,new_time) + + # 1) Flatten + B, C, H, W = x.shape # 동적 shape + x = x.view(B, -1) # (B, 32*H*W) + + # 2) FC + x = self.fc_block(x) + return x + + +def my_collate_fn(batch): + mel_list, label_list = zip(*batch) + + max_frames = max(m.shape[2] for m in mel_list) + + padded = [] + for m in mel_list: + diff = max_frames - m.shape[2] + if diff > 0: + print(f"Padding applied: Original frames = {m.shape[2]}, Target frames = {max_frames}") + m = F.pad(m, (0, diff), mode='constant', value=0) + padded.append(m) + + + mel_batch = torch.stack(padded, dim=0) + label_batch = torch.tensor(label_list, dtype=torch.long) + return mel_batch, label_batch + + +class EarlyStopping: + def __init__(self, patience=5, delta=0, path='./ckpt/mfcc/early_stop_best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth', verbose=False): + self.patience = patience + self.delta = delta + self.path = path + self.verbose = verbose + self.counter = 0 + self.best_loss = None + self.early_stop = False + + def __call__(self, val_loss, model): + if self.best_loss is None: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + elif val_loss > self.best_loss - self.delta: + self.counter += 1 + if 
self.verbose: + print(f"EarlyStopping counter: {self.counter} out of {self.patience}") + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_loss = val_loss + self._save_checkpoint(val_loss, model) + self.counter = 0 + + def _save_checkpoint(self, val_loss, model): + if self.verbose: + print(f"Validation loss decreased ({self.best_loss:.6f} --> {val_loss:.6f}). Saving model ...") + torch.save(model.state_dict(), self.path) + +def train(batch_size, epochs, learning_rate, root_dir="audio/FakeMusicCaps"): + if not os.path.exists("./ckpt/mfcc/"): + os.makedirs("./ckpt/mfcc/") + + wandb.init( + project="AI Music Detection", + name=f"mfcc_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}", + config={"batch_size": batch_size, "epochs": epochs, "learning_rate": learning_rate}, + ) + + dataset = RealFakeDataset(root_dir=root_dir) + n_total = len(dataset) + n_train = int(n_total * 0.8) + n_val = n_total - n_train + train_ds, val_ds = random_split(dataset, [n_train, n_val]) + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) + val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate_fn) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = AudioCNN(num_classes=2).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + best_val_loss = float('inf') + patience = 3 + patience_counter = 0 + + for epoch in range(1, epochs + 1): + print(f"\n[Epoch {epoch}/{epochs}]") + + # Training + model.train() + train_loss, train_correct, train_total = 0, 0, 0 + train_pbar = tqdm(train_loader, desc="Train", leave=False) + for mel_batch, labels in train_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + optimizer.zero_grad() + outputs = model(mel_batch) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + train_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + train_correct += (preds == labels).sum().item() + train_total += labels.size(0) + + train_pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + train_loss /= train_total + train_acc = train_correct / train_total + + # Validation + model.eval() + val_loss, val_correct, val_total = 0, 0, 0 + all_preds, all_labels = [], [] + val_pbar = tqdm(val_loader, desc=" Val ", leave=False) + with torch.no_grad(): + for mel_batch, labels in val_pbar: + mel_batch, labels = mel_batch.to(device), labels.to(device) + outputs = model(mel_batch) + loss = criterion(outputs, labels) + val_loss += loss.item() * mel_batch.size(0) + preds = outputs.argmax(dim=1) + val_correct += (preds == labels).sum().item() + val_total += labels.size(0) + all_preds.extend(preds.cpu().numpy()) + all_labels.extend(labels.cpu().numpy()) + + val_loss /= val_total + val_acc = val_correct / val_total + val_precision = precision_score(all_labels, all_preds, average="macro") + val_recall = recall_score(all_labels, all_preds, average="macro") + val_f1 = f1_score(all_labels, all_preds, average="macro") + + print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.3f} | " + f"Val Loss: {val_loss:.4f} Acc: {val_acc:.3f} " + f"Precision: {val_precision:.3f} Recall: {val_recall:.3f} F1: {val_f1:.3f}") + + wandb.log({"train_loss": train_loss, "train_acc": train_acc, + "val_loss": val_loss, "val_acc": val_acc, + "val_precision": val_precision, "val_recall": val_recall, "val_f1": val_f1}) + + if val_loss < best_val_loss: + best_val_loss = val_loss + 
patience_counter = 0 + best_model_path = f"./ckpt/mfcc/best_batch_{batch_size}_epochs_{epochs}_lr_{learning_rate}.pth" + torch.save(model.state_dict(), best_model_path) + print(f"[INFO] New best model saved: {best_model_path}") + else: + patience_counter += 1 + if patience_counter >= patience: + print("Early stopping triggered!") + break + + wandb.finish() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train AI Music Detection model.") + parser.add_argument('--batch_size', type=int, required=True, help="Batch size for training") + parser.add_argument('--epochs', type=int, required=True, help="Number of epochs") + parser.add_argument('--learning_rate', type=float, required=True, help="Learning rate") + parser.add_argument('--root_dir', type=str, default="audio/FakeMusicCaps", help="Root directory for dataset") + + args = parser.parse_args() + + train(batch_size=args.batch_size, epochs=args.epochs, learning_rate=args.learning_rate, root_dir=args.root_dir) diff --git a/ISMIR_2025/wav2vec/utils/utilities.py b/ISMIR_2025/wav2vec/utils/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..e0be98e8645b8bb1c838d3dc9ae49daac706df62 --- /dev/null +++ b/ISMIR_2025/wav2vec/utils/utilities.py @@ -0,0 +1,305 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import logging +import pickle + +import numpy as np + +from scipy import stats + +import csv +import json + +def create_folder(fd): + if not os.path.exists(fd): + os.makedirs(fd, exist_ok=True) + + +def get_filename(path): + path = os.path.realpath(path) + na_ext = path.split("/")[-1] + na = os.path.splitext(na_ext)[0] + return na + + +def get_sub_filepaths(folder): + paths = [] + for root, dirs, files in os.walk(folder): + for name in files: + path = os.path.join(root, name) + paths.append(path) + return paths + + +def create_logging(log_dir, filemode): + create_folder(log_dir) + i1 = 0 + + while os.path.isfile(os.path.join(log_dir, "{:04d}.log".format(i1))): + i1 += 1 + + log_path = os.path.join(log_dir, "{:04d}.log".format(i1)) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=log_path, + filemode=filemode, + ) + + # Print to console + console = logging.StreamHandler() + console.setLevel(logging.INFO) + formatter = logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") + console.setFormatter(formatter) + logging.getLogger("").addHandler(console) + + return logging + + +def read_metadata(csv_path, audio_dir, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # first, count the audio names only of existing files on disk only + + audios_num = 0 + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if os.path.exists(os.path.join(audio_dir, audio_name)): + audios_num += 1 + + print("CSV audio files: %d" % (len(lines))) + print("Existing audio files: %d" % audios_num) + + # audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + n = 0 + for line in lines: + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + # audio_name = 'Y{}.wav'.format(items[0]) # Audios are started with an extra 'Y' when downloading + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) + audio_name = audio_name.replace("_0000_", "_0_") + + if not os.path.exists(os.path.join(audio_dir, audio_name)): + continue + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + n += 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + + +def read_audioset_ontology(id_to_ix): + with open('../metadata/audioset_ontology.json', 'r') as f: + data = json.load(f) + + # Output: {'name': 'Bob', 'languages': ['English', 'French']} + sentences = [] + for el in data: + print(el.keys()) + id = el['id'] + if id in id_to_ix: + name = el['name'] + desc = el['description'] + # if '(' in desc: + # print(name, '---', desc) + # print(id_to_ix[id], name, '---', ) + + # sent = name + # sent = name + ', ' + desc.replace('(', '').replace(')', '').lower() + # sent = desc.replace('(', '').replace(')', '').lower() + # sentences.append(sent) + sentences.append(desc) + # print(sent) + # break + return sentences + + +def original_read_metadata(csv_path, classes_num, id_to_ix): + """Read metadata of AudioSet from a csv file. 
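+
+    Note: this variant indexes every CSV row; the on-disk existence check lives
+    in read_metadata above. Sketch (the path and class count are placeholder
+    assumptions):
+
+        # id_to_ix as returned by read_audioset_label_tags(...)
+        meta = original_read_metadata("eval_segments.csv", 527, id_to_ix)
+        # meta["target"]: (num_rows_after_header, classes_num)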
+ + Args: + csv_path: str + + Returns: + meta_dict: {'audio_name': (audios_num,), 'target': (audios_num, classes_num)} + """ + + with open(csv_path, "r") as fr: + lines = fr.readlines() + lines = lines[3:] # Remove heads + + # Thomas Pellegrini: added 02/12/2022 + # check if the audio files indeed exist, otherwise remove from list + + audios_num = len(lines) + targets = np.zeros((audios_num, classes_num), dtype=bool) + audio_names = [] + + for n, line in enumerate(lines): + items = line.split(", ") + """items: ['--4gqARaEJE', '0.000', '10.000', '"/m/068hy,/m/07q6cd_,/m/0bt9lr,/m/0jbk"\n']""" + + audio_name = "{}_{}_{}.flac".format( + items[0], items[1].replace(".", ""), items[2].replace(".", "") + ) # Audios are started with an extra 'Y' when downloading + audio_name = audio_name.replace("_0000_", "_0_") + + label_ids = items[3].split('"')[1].split(",") + + audio_names.append(audio_name) + + # Target + for id in label_ids: + ix = id_to_ix[id] + targets[n, ix] = 1 + + meta_dict = {"audio_name": np.array(audio_names), "target": targets} + return meta_dict + +def read_audioset_label_tags(class_labels_indices_csv): + with open(class_labels_indices_csv, 'r') as f: + reader = csv.reader(f, delimiter=',') + lines = list(reader) + + labels = [] + ids = [] # Each label has a unique id such as "/m/068hy" + for i1 in range(1, len(lines)): + id = lines[i1][1] + label = lines[i1][2] + ids.append(id) + labels.append(label) + + classes_num = len(labels) + + lb_to_ix = {label : i for i, label in enumerate(labels)} + ix_to_lb = {i : label for i, label in enumerate(labels)} + + id_to_ix = {id : i for i, id in enumerate(ids)} + ix_to_id = {i : id for i, id in enumerate(ids)} + + return lb_to_ix, ix_to_lb, id_to_ix, ix_to_id + + + +def float32_to_int16(x): + # assert np.max(np.abs(x)) <= 1.5 + x = np.clip(x, -1, 1) + return (x * 32767.0).astype(np.int16) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def pad_or_truncate(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x[0:audio_length] + + +def pad_audio(x, audio_length): + """Pad all audio to specific length.""" + if len(x) <= audio_length: + return np.concatenate((x, np.zeros(audio_length - len(x))), axis=0) + else: + return x + + +def d_prime(auc): + d_prime = stats.norm().ppf(auc) * np.sqrt(2.0) + return d_prime + + +class Mixup(object): + def __init__(self, mixup_alpha, random_seed=1234): + """Mixup coefficient generator.""" + self.mixup_alpha = mixup_alpha + self.random_state = np.random.RandomState(random_seed) + + def get_lambda(self, batch_size): + """Get mixup random coefficients. 
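+
+        Illustrative sketch: coefficients are drawn in (lam, 1 - lam) pairs so
+        that consecutive items in a batch can be mixed with complementary weights:
+
+            mixup = Mixup(mixup_alpha=1.0)            # alpha value is an example choice
+            lambdas = mixup.get_lambda(batch_size=4)  # e.g. [l0, 1-l0, l1, 1-l1]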
+ Args: + batch_size: int + Returns: + mixup_lambdas: (batch_size,) + """ + mixup_lambdas = [] + for n in range(0, batch_size, 2): + lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0] + mixup_lambdas.append(lam) + mixup_lambdas.append(1.0 - lam) + + return np.array(mixup_lambdas) + + +class StatisticsContainer(object): + def __init__(self, statistics_path): + """Contain statistics of different training iterations.""" + self.statistics_path = statistics_path + + self.backup_statistics_path = "{}_{}.pkl".format( + os.path.splitext(self.statistics_path)[0], + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), + ) + + self.statistics_dict = {"bal": [], "test": []} + + def append(self, iteration, statistics, data_type): + statistics["iteration"] = iteration + self.statistics_dict[data_type].append(statistics) + + def dump(self): + pickle.dump(self.statistics_dict, open(self.statistics_path, "wb")) + pickle.dump(self.statistics_dict, open(self.backup_statistics_path, "wb")) + logging.info(" Dump statistics to {}".format(self.statistics_path)) + logging.info(" Dump statistics to {}".format(self.backup_statistics_path)) + + def load_state_dict(self, resume_iteration): + self.statistics_dict = pickle.load(open(self.statistics_path, "rb")) + + resume_statistics_dict = {"bal": [], "test": []} + + for key in self.statistics_dict.keys(): + for statistics in self.statistics_dict[key]: + if statistics["iteration"] <= resume_iteration: + resume_statistics_dict[key].append(statistics) + + self.statistics_dict = resume_statistics_dict \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..56cfe643121989e674fd868e6cdfb3b9852cd5e9 --- /dev/null +++ b/app.py @@ -0,0 +1,189 @@ +import gradio as gr +import torch +import librosa +import numpy as np +from inference import inference + +def detect_ai_audio(audio_file): + """ + Detect whether the uploaded audio file was generated by AI + """ + result = inference(audio_file) + + # Format result with better styling + if "AI" in str(result).upper() or "artificial" in str(result).lower(): + status = "AI Generated" + color = "#ff6b6b" + else: + status = "Human Generated" + color = "#51cf66" + + formatted_result = f""" +
Advanced AI technology that accurately detects whether uploaded audio was generated by AI!
+Supported formats: MP3, WAV, M4A, FLAC, and other common audio formats
+Fast and accurate real-time analysis
+