| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| | import numpy as np |
| | import pandas as pd |
| | import pytest |
| |
|
| | from src.utils import train_test_split_and_feature_extraction |
| |
|
| | |
| | |
| | |
| |
|
| |
|
@pytest.fixture
def big_fake_data():
    """Build a 100-row synthetic DataFrame for the split/extraction test.

    Column layout, in order:
      * ``id``       -- integers 1..100
      * ``image``    -- string paths ``path/<i>.jpg`` (not an embedding)
      * ``image_0`` .. ``image_9``  -- 10 random float columns
      * ``text_0``  .. ``text_10``  -- 11 random float columns
      * ``class_id`` -- one of ``label1`` / ``label2`` / ``label3``

    Returns:
        pd.DataFrame: the frame described above.
    """
    num_rows = 100
    num_image_columns = 10
    num_text_columns = 11

    # A locally seeded generator keeps the fixture reproducible from run to
    # run (so a failure can be replayed) and avoids touching NumPy's global
    # random state, which other tests may rely on.
    rng = np.random.default_rng(0)

    data = {
        "id": np.arange(1, num_rows + 1),
        "image": [f"path/{i}.jpg" for i in range(1, num_rows + 1)],
    }

    # Fake image-embedding features.
    for i in range(num_image_columns):
        data[f"image_{i}"] = rng.random(num_rows)

    # Fake text-embedding features.
    for i in range(num_text_columns):
        data[f"text_{i}"] = rng.random(num_rows)

    # Categorical label column.
    data["class_id"] = rng.choice(["label1", "label2", "label3"], size=num_rows)

    return pd.DataFrame(data)
| |
|
| |
|
def test_train_test_split_and_feature_extraction(big_fake_data):
    """Check column discovery, split sizes and seed reproducibility.

    Calls ``train_test_split_and_feature_extraction`` on the synthetic frame
    with ``test_size=0.3`` and ``random_state=42`` and verifies that:

      * the returned text / image / label column lists match the fixture's
        ``text_*``, ``image_*`` and ``class_id`` columns;
      * the plain ``image`` column is not reported as an image feature;
      * the 100 input rows are divided into 70 train and 30 test rows;
      * a second call with the same arguments yields the same row indices.
    """
    train_df, test_df, text_columns, image_columns, label_columns = (
        train_test_split_and_feature_extraction(
            big_fake_data, test_size=0.3, random_state=42
        )
    )

    # Column extraction: the fixture has 11 text and 10 image feature columns.
    assert text_columns == [f"text_{i}" for i in range(11)], (
        "The text embedding columns extraction is incorrect"
    )
    assert image_columns == [f"image_{i}" for i in range(10)], (
        "The image embedding columns extraction is incorrect"
    )
    assert label_columns == ["class_id"], (
        "The label column extraction is incorrect, should be 'class_id'"
    )

    # The bare "image" column holds file paths, not features; the message
    # describes the violation so a failure reads correctly.
    assert "image" not in image_columns, (
        "'image' column must not be part of the embedding columns"
    )

    # Split sizes are row counts (70/30 of the fixture's 100 rows), so the
    # messages report rows rather than appending a misleading '%'.
    assert len(train_df) == 70, (
        f"Train set should have 70 rows (70% of 100), but got {len(train_df)}"
    )
    assert len(test_df) == 30, (
        f"Test set should have 30 rows (30% of 100), but got {len(test_df)}"
    )

    # Remember which rows landed in each split ...
    expected_train_indices = train_df.index.tolist()
    expected_test_indices = test_df.index.tolist()

    # ... and repeat the call with identical arguments: a fixed random_state
    # must reproduce exactly the same split.
    train_df_recheck, test_df_recheck, _, _, _ = (
        train_test_split_and_feature_extraction(
            big_fake_data, test_size=0.3, random_state=42
        )
    )

    assert expected_train_indices == train_df_recheck.index.tolist(), (
        "Train set indices are not consistent with the random state"
    )
    assert expected_test_indices == test_df_recheck.index.tolist(), (
        "Test set indices are not consistent with the random state"
    )
| |
|
| |
|
if __name__ == "__main__":
    # pytest.main() returns the session's exit code; hand it to the
    # interpreter so running this file as a script exits non-zero when
    # tests fail instead of always reporting success.
    raise SystemExit(pytest.main())
| |
|