File size: 3,819 Bytes
b53f252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


import torch.utils.data as data
# from PIL import Image
import soundfile as sf
import PIL
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd
from scipy import signal

from miscc.config import cfg


class TextDataset(data.Dataset):
    """Dataset pairing room impulse responses (RIRs) with pickled embeddings.

    Files read by this class::

        <data_dir>/<split>/filenames.pickle   # indexable, sized sequence of keys
        <data_dir>/<split>/embeddings.pickle  # object indexable by each key
        <data_dir>/RIR/<key>.wav              # one audio file per key

    Each item is ``(RIR, embedding)``: ``RIR`` is the float32 output of
    :meth:`get_RIR` and ``embedding`` is ``embeddings[key]`` converted to a
    float32 array.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only point this class at dataset directories you trust.
    """

    # Number of samples every RIR is cropped or zero-padded to.
    RIR_CROP_LENGTH = 4096

    def __init__(self, data_dir, split='train', rirsize=4096):
        """Load the key list and embeddings for one split.

        Args:
            data_dir: dataset root; split subdirectories and ``RIR/`` live here.
            split: name of the split subdirectory (default ``'train'``).
            rirsize: stored on the instance. NOTE(review): it does not
                currently change the RIR length returned by ``get_RIR``.
        """
        self.rirsize = rirsize
        self.data = []  # kept for backward compatibility; unused in this class
        self.data_dir = data_dir
        self.bbox = None  # kept for backward compatibility; unused in this class

        split_dir = os.path.join(data_dir, split)

        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir)

    def get_RIR(self, RIR_path):
        """Read an audio file and crop/zero-pad it to ``RIR_CROP_LENGTH``.

        Args:
            RIR_path: path to a file readable by ``soundfile.read``.

        Returns:
            float32 array of shape ``(1, RIR_CROP_LENGTH)`` for a mono file
            (``(1, RIR_CROP_LENGTH, channels)`` for a multi-channel file).

        NOTE(review): the resampling step that would honor ``self.rirsize``
        is disabled, so every RIR has exactly ``RIR_CROP_LENGTH`` samples
        regardless of ``rirsize``.
        """
        wav, _ = sf.read(RIR_path)
        crop_length = self.RIR_CROP_LENGTH
        # Measure and pad along the time axis (axis 0) only; ``wav.size``
        # counts every channel and a 1-D concatenate fails on 2-D input.
        n_frames = len(wav)
        if n_frames < crop_length:
            pad_width = [(0, crop_length - n_frames)]
            pad_width += [(0, 0)] * (wav.ndim - 1)
            rir = np.pad(wav, pad_width, mode='constant')
        else:
            rir = wav[:crop_length]
        # Add a leading axis of size 1 (e.g. a channel dim for the model).
        return np.array([rir]).astype('float32')

    def load_embedding(self, data_dir):
        """Unpickle and return ``<data_dir>/embeddings.pickle``."""
        filepath = os.path.join(data_dir, 'embeddings.pickle')
        with open(filepath, 'rb') as f:
            # NOTE: pickle.load executes code from the file; trusted data only.
            embeddings = pickle.load(f)
        return embeddings

    def load_filenames(self, data_dir):
        """Unpickle and return ``<data_dir>/filenames.pickle``, logging its size."""
        filepath = os.path.join(data_dir, 'filenames.pickle')
        with open(filepath, 'rb') as f:
            # NOTE: pickle.load executes code from the file; trusted data only.
            filenames = pickle.load(f)
        print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        return filenames

    def __getitem__(self, index):
        """Return ``(RIR, embedding)`` for the key at position ``index``."""
        key = self.filenames[index]
        # Look the embedding up first so a missing key fails before file I/O.
        embeddings = self.embeddings[key]
        RIR_name = os.path.join(self.data_dir, 'RIR', '%s.wav' % (key,))
        RIR = self.get_RIR(RIR_name)
        embedding = np.array(embeddings).astype('float32')
        return RIR, embedding

    def __len__(self):
        """Number of keys in this split."""
        return len(self.filenames)