{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Physics NeMo External Aerodynamics DLI\n",
"\n",
"## Notebook 3 - Training the DoMINO Model on the Ahmed Body Surface Dataset and Running Inference\n",
"\n",
"In this notebook, we will first provide a detailed explanation of the DoMINO architecture, which is a multi-scale, iterative neural operator designed for modeling large-scale engineering simulations. We will break down the key components of DoMINO, including its use of local geometry representations, multi-scale point convolution kernels, and its efficient handling of complex geometries. Afterward, we will train the model using the **Ahmed body surface dataset**, a widely used dataset in automotive aerodynamics simulations. *As indicated in the previous notebook, this dataset was created by the NVIDIA Physics NeMo development team and differs from other similar datasets hosted on cloud platforms like AWS.*\n",
"\n",
"*The DoMINO model can be trained on both volume fields (such as velocity and pressure) and surface fields (including pressure and wall shear stress). However, for simplicity and for educational purposes, this notebook **focuses solely on training the surface fields** using the Ahmed body surface dataset.*\n",
"\n",
"## Guidelines\n",
"- Before starting training, make sure GPU memory from earlier work is released by terminating the kernel process (Jupyter restarts the kernel automatically), for example by running:\n",
"```python\n",
"import os\n",
"os._exit(0)\n",
"```\n",
"- Run the data preprocessing notebook first; it prepares the data in NPY format.\n",
"- The amount of GPU memory allocated during training can be controlled through the following configuration parameters:\n",
"```python\n",
"interp_res = GRID_RESOLUTION   # Resolution of the latent-space grid; coarser resolutions reduce memory usage\n",
"NUM_SURFACE_NEIGHBORS = 7      # Number of neighboring surface points used to compute solution variables such as pressure\n",
"SURFACE_POINTS_SAMPLE = 8192   # Number of surface points sampled per epoch; fewer points reduce memory consumption\n",
"```\n",
" To adjust GPU memory usage, modify these values in the \"Experiment Parameters and Variables\" section.\n",
"\n",
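"As a rough illustration of why `interp_res` dominates memory, the sketch below counts the grid points for two resolutions and estimates the size of one float32 feature grid. The channel count `num_features=32` is a hypothetical value chosen for illustration, not a DoMINO default, and real usage also includes activations, gradients, and optimizer state.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def grid_feature_megabytes(grid_resolution, num_features, bytes_per_value=4):\n",
"    # Number of grid points is the product of the per-axis resolutions\n",
"    num_points = math.prod(grid_resolution)\n",
"    # One float32 value (4 bytes) per grid point per feature channel\n",
"    return num_points * num_features * bytes_per_value / 1024**2\n",
"\n",
"# num_features=32 is a hypothetical channel count used only for illustration\n",
"for res in ([128, 64, 48], [64, 32, 24]):\n",
"    print(res, math.prod(res), grid_feature_megabytes(res, num_features=32))\n",
"# [128, 64, 48] 393216 48.0\n",
"# [64, 32, 24] 49152 6.0\n",
"```\n",
"\n",
"Halving the resolution along every axis reduces the number of grid points, and therefore the memory of every grid-shaped tensor, by a factor of 8.\n",
"\n",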
"\n",
"# Table of Contents\n",
"- [DoMINO Architecture](#DoMINO-Architecture)\n",
" - [Geometric Understanding: Global and Local Perspectives](#Geometric-Understanding:-Global-and-Local-Perspectives)\n",
" - [The Prediction Engine: Basis Functions and the Aggregation Network](#The-Prediction-Engine:-Basis-Functions-and-the-Aggregation-Network)\n",
"- [Training Process](#training)\n",
" - [Step 1: Define Experiment Parameters and Dependencies](#step-1-define-experiment-parameters-and-dependencies)\n",
" - [Loading Required Libraries](#loading-required-libraries)\n",
" - [Dependencies](#dependencies)\n",
" - [Experiment Parameters and Variables](#experiment-parameters-and-variables)\n",
" - [Step 2: Train the DoMINO Model](#step-2-train-the-domino-model)\n",
" - [Understanding the Training Process](#understanding-the-training-process)\n",
" - [Key Components and Libraries](#key-components-and-libraries)\n",
" - [Important Training Parameters](#important-training-parameters)\n",
" - [Implementation Overview](#implementation-overview)\n",
"- [Load Model Checkpoint & Run Inference](#Load-Model-Checkpoint-&-Run-Inference)\n",
"- [Visualizing the predicted results](#Visualizing-the-predicted-results)\n",
"\n",
"\n",
"## DoMINO Architecture\n",
"\n",
"DoMINO, which stands for Decomposable Multiscale Iterative Neural Operator, is a novel machine learning model designed to address key challenges in accelerating simulations, such as accuracy, scalability, and generalization to new geometries—particularly in the context of automotive aerodynamics. As a neural operator, DoMINO is capable of predicting point-wise volume and surface fields and is inherently scalable to large domains. Its decomposable architecture is central to its performance: it learns local geometric representations within sub-regions of the domain, enabling greater accuracy by focusing on areas with detailed physical features. The model also employs a multiscale approach using learnable point kernels, allowing it to capture both fine and coarse geometric patterns directly from STL files.\n",
"\n",
"DoMINO operates iteratively to enable long-range interactions, progressively propagating the learned local geometric features across the entire computational domain. Separately, it builds dynamic, point-based computational stencils—drawing inspiration from traditional numerical methods—which it uses to learn non-linear basis functions tailored to the geometry. These basis functions, combined with localized geometry encodings, allow the model to predict volume and surface solution fields at specified points. A major advantage of DoMINO is that it only requires STL geometry as input during inference; it does not depend on mesh generation or the density and structure of input point clouds. The process begins by encoding global geometry onto a fixed grid using a combination of learnable point convolutions, CNNs, and dense layers. From this, the model constructs localized subdomains around evaluation points to extract rich geometric encodings, which are then used in conjunction with the learned basis functions for final prediction.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Geometric Understanding: Global and Local Perspectives**\n",
"DoMINO's ability to accurately predict complex flow phenomena relies heavily on its sophisticated understanding of geometry, processed at both global and local scales.\n",
"\n",
"\n",
"- **A.** **Global Geometry Representation**:\n",
"The Global Geometry Representation refers to the overall shape and structure of the entire object or domain that you are modeling. This representation captures all the geometric details across the entire computational domain.\\\n",
"Step-by-Step Explanation of Global Geometry Representation:\n",
"\n",
" - **Step 1**: Construct Bounding Boxes\n",
" \t- A tight-fitting surface bounding box is created around the STL (3D geometry) to hold the geometry.\n",
" - A computational domain bounding box is also defined, which is larger than the surface bounding box to encompass the whole computational domain.\n",
" - Both bounding boxes can be specified in ```conf.yaml```\n",
" - **Step 2**: Project STL Vertices onto Structured Grid\n",
" \t- The geometric features of the point cloud, such as spatial coordinates, are projected onto an N-dimensional structured grid of resolution m×m×m×f, which is overlaid on the surface bounding box using **learnable point convolution kernels**.\n",
" \t- The learnable point convolution kernels are created using **differentiable ball query layers**. This means that the method:\n",
" \t- Uses a \"ball\" (a sphere in 3D space) around each point to query or find its neighbors.\n",
" \t- The ball query layer is \"differentiable,\" meaning it can be included in the neural network and updated via backpropagation (i.e., during training, the network can learn how to adjust the kernels to improve performance).\n",
" \t- The radius of the ball (radius of influence) defines how far around each point we look for neighboring points to include in the convolution, and therefore how far the geometry can influence the grid. A range of point convolution kernel sizes can be learned by specifying several radii. Moreover, separate kernels are learned to represent information on the surface bounding box and on the computational domain bounding box. This enables multi-scale learning of the geometry encoding, representing both short- and long-range interactions of the surface and flow fields. The radii of influence are defined as a **list** in the ```conf.yaml``` file:\n",
" ```yaml\n",
" volume_radii: [0.1, 0.5]\n",
" surface_radii: [0.05]\n",
" ```\n",
" These radii are used in the DoMINO model (```physicsnemo/models/domino/model.py```) to construct one **BQWarp** layer per radius:\n",
" \n",
"\n",
" \n",
" ```python\n",
" # Excerpt; BQWarp and GeoProcessor come from the same DoMINO model module\n",
" import torch.nn as nn\n",
" \n",
" class GeometryRep(nn.Module):\n",
"     def __init__(self, input_features: int, radii, model_parameters=None):\n",
"         \"\"\"\n",
"         Initialize the GeometryRep module.\n",
" \n",
"         Args:\n",
"             input_features: Number of input feature dimensions\n",
"             radii: Radii of influence; one BQWarp and one GeoProcessor are created per radius\n",
"             model_parameters: Configuration parameters for the model\n",
"         \"\"\"\n",
"         super().__init__()\n",
"         geometry_rep = model_parameters.geometry_rep\n",
"         self.geo_encoding_type = model_parameters.geometry_encoding_type\n",
" \n",
"         self.bq_warp = nn.ModuleList()\n",
"         self.geo_processors = nn.ModuleList()\n",
"         for j, p in enumerate(radii):\n",
"             # Ball-query projection of the point cloud onto the grid for this radius\n",
"             self.bq_warp.append(\n",
"                 BQWarp(\n",
"                     grid_resolution=model_parameters.interp_res,\n",
"                     radius=radii[j],\n",
"                 )\n",
"             )\n",
"             # CNN blocks that refine the projected grid features\n",
"             self.geo_processors.append(\n",
"                 GeoProcessor(\n",
"                     input_filters=geometry_rep.geo_conv.base_neurons_out,\n",
"                     model_parameters=geometry_rep.geo_processor,\n",
"                 )\n",
"             )\n",
" ```\n",
" - **Step 3**: Use Multi-Resolution Approach for Detailed and Coarse Features\n",
" \t- The grid resolution in the bounding box determines the level of detail of the geometry: \n",
" \t- Finer resolution captures more detailed features of the geometry.\n",
" \t- Coarser resolution captures larger, broader features.\n",
" \t- A multi-resolution approach is adopted, meaning that grids at several resolutions (levels) can be maintained to capture both fine and coarse features of the geometry. The grid resolution itself is given as the number of grid points along x, y, and z, e.g. `GRID_RESOLUTION = [128, 64, 48]` (the resolution of the interpolation grid, passed to the model as `interp_res`).\n",
" - **Currently, the DoMINO model allows only a single resolution to be specified; support for configuring multiple resolution levels will be provided in a future release.**\n",
"\n",
" - **Step 4**: Propagate Geometry Features into the Computational Domain\n",
" - The computational domain is much larger than the surface bounding box, so the geometry information needs to be extended.\n",
" \t- Geometry features are propagated into the computational domain using two methods: \n",
" \t- As explained in **Step 2**, multi-scale **point convolution kernels** project the **geometry information** onto the surface bounding box (**see the left figure below**).\n",
" \t- **Features** from the surface bounding-box grid (i.e., Gs) are propagated into the computational domain grid (i.e., Gc) using **CNN blocks** that contain convolution, pooling, and unpooling layers (**see the left figure below**).\n",
" As the code snippet above shows, each `BQWarp` layer is paired with a `GeoProcessor`, which contains these CNN blocks:\n",
" ```python\n",
" class GeoProcessor(nn.Module):\n",
" ```\n",
" - The **CNN blocks are iterated** for a specified number of steps to refine the geometry representation. Currently, the DoMINO model is configured to run a single iteration. An option to change this will be provided in ```conf.yaml``` in a future release.\n",
"\n",
" - **Step 5**: Calculate Signed Distance Function (SDF) and its Gradients\n",
" \t- Additionally, the Signed Distance Function (SDF) and its gradient components are calculated on the computational domain grid.\n",
" \t- These SDF and gradient values are added to the learned features, providing additional information about the topology of the geometry (i.e., the geometry's shape, distances to surfaces, etc.).\n",
"\n",
" - **Step 6**: Final Global Geometry Representation\n",
" \t- The final geometry representation of the STL is formed by combining the learned features from the structured grids at different resolutions in both the bounding box and the computational domain.\n",
" \n",
" \n",
" \n",
" **Once the global geometry representation has been built, the next step is the local geometry representation.**\n",
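"\n",
"A minimal NumPy sketch of Steps 1 and 2, assuming a random point cloud in place of the STL vertices: it builds a tight bounding box, overlays an nx × ny × nz grid on it, and, for every grid point, averages the coordinates of the vertices that fall inside a ball of a given radius. The function names and the fixed mean aggregation are illustrative assumptions; DoMINO's `BQWarp` uses a differentiable ball query whose outputs are further processed by learnable layers.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def surface_bounding_box(vertices):\n",
"    # Tight axis-aligned box around the vertices (Step 1)\n",
"    return vertices.min(axis=0), vertices.max(axis=0)\n",
"\n",
"def make_grid(lo, hi, resolution):\n",
"    # Structured grid of shape (nx, ny, nz, 3) overlaid on the box (Step 2)\n",
"    axes = [np.linspace(lo[d], hi[d], resolution[d]) for d in range(3)]\n",
"    gx, gy, gz = np.meshgrid(*axes, indexing='ij')\n",
"    return np.stack([gx, gy, gz], axis=-1)\n",
"\n",
"def ball_query_mean(grid, vertices, radius):\n",
"    # For each grid point, average the coordinates of the vertices within `radius`;\n",
"    # grid points without neighbors get zero features and a count of 0\n",
"    flat = grid.reshape(-1, 3)\n",
"    dist = np.linalg.norm(flat[:, None, :] - vertices[None, :, :], axis=-1)\n",
"    mask = dist <= radius\n",
"    counts = mask.sum(axis=1)\n",
"    sums = mask.astype(float) @ vertices\n",
"    feats = np.where(counts[:, None] > 0, sums / np.maximum(counts, 1)[:, None], 0.0)\n",
"    return feats.reshape(grid.shape), counts.reshape(grid.shape[:-1])\n",
"\n",
"rng = np.random.default_rng(0)\n",
"vertices = rng.uniform(-1.0, 1.0, size=(500, 3))   # stand-in for STL vertices\n",
"lo, hi = surface_bounding_box(vertices)\n",
"grid = make_grid(lo, hi, (8, 6, 4))\n",
"for radius in (0.1, 0.5):                           # several radii of influence\n",
"    feats, counts = ball_query_mean(grid, vertices, radius)\n",
"    print(radius, feats.shape, counts.mean())\n",
"```\n",
"\n",
"A larger radius gathers more vertices per grid point, which is the sense in which several radii give a multi-scale view of the geometry.\n",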
"\n",
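"To make Step 5 concrete, the sketch below evaluates the signed distance function of a sphere on a structured grid and takes its gradient with `np.gradient`. The sphere is an analytic stand-in for the STL, chosen because its exact SDF is known, and the grid size is arbitrary; for the real mesh, the notebook imports `signed_distance_field` from `physicsnemo.utils.sdf`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sphere_sdf(points, center, radius):\n",
"    # Negative inside the sphere, zero on its surface, positive outside\n",
"    return np.linalg.norm(points - center, axis=-1) - radius\n",
"\n",
"n = 32\n",
"coords = np.linspace(-1.0, 1.0, n)\n",
"gx, gy, gz = np.meshgrid(coords, coords, coords, indexing='ij')\n",
"grid = np.stack([gx, gy, gz], axis=-1)   # (n, n, n, 3)\n",
"\n",
"sdf = sphere_sdf(grid, center=np.zeros(3), radius=0.5)\n",
"spacing = coords[1] - coords[0]\n",
"# Central differences (one-sided at the boundary) give the x, y, z gradient components\n",
"dsdf_dx, dsdf_dy, dsdf_dz = np.gradient(sdf, spacing)\n",
"\n",
"# SDF and its gradient stacked as extra features on every grid point\n",
"features = np.stack([sdf, dsdf_dx, dsdf_dy, dsdf_dz], axis=-1)\n",
"print(features.shape)   # (32, 32, 32, 4)\n",
"grad_norm = np.sqrt(dsdf_dx**2 + dsdf_dy**2 + dsdf_dz**2)\n",
"print(np.median(grad_norm))\n",
"```\n",
"\n",
"For an exact SDF the gradient has unit length almost everywhere, so the median gradient norm printed by the sketch should be close to 1.\n",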
"\n",
"- **B.** **Local geometry representation**\n",
"The Local Geometry Representation focuses on the geometry in the immediate vicinity of a sampled point p (a point of the simulation mesh). The idea is to understand how the geometry behaves around a specific point and its neighbors, which can be important for accurate predictions. While the Global Geometry Representation gives the big picture, the Local Geometry Representation zooms in on a small region of interest around each sampled point: it represents a smaller, more detailed portion of the global geometry, focusing on small-scale features close to the point. For each sampled point p, neighboring points are sampled randomly around it to form a computational stencil, similar to finite volume and finite element methods. The local geometry representation is learned by drawing a subregion around this computational stencil (the point p plus its neighbors). The sizes of the subregions are defined as a **list** in the ```conf.yaml``` file:\n",
"\n",
" ```yaml\n",
" geometry_local.volume_radii: [0.05, 0.1]\n",
" geometry_local.surface_radii: [0.05]\n",
" ```\n",
"As in the global geometry representation, a point convolution kernel is used here to extract local features within the subregion from the global geometry representation on the computational domain.\n",
"The **BQWarp** layers for the local geometry representation are constructed in ```physicsnemo/models/domino/model.py```, inside ```class DoMINO(nn.Module)```:\n",
"\n",
"```python\n",
"for ct, j in enumerate(self.surface_radius):\n",
"    if self.geo_encoding_type == \"both\":\n",
"        total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * (\n",
"            len(model_parameters.geometry_rep.geo_conv.surface_radii) + 1\n",
"        )\n",
"    elif self.geo_encoding_type == \"stl\":\n",
"        total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * (\n",
"            len(model_parameters.geometry_rep.geo_conv.surface_radii)\n",
"        )\n",
"    elif self.geo_encoding_type == \"sdf\":\n",
"        total_neighbors_in_radius = self.surface_neighbors_in_radius[ct]\n",
"\n",
"    self.surface_bq_warp.append(\n",
"        BQWarp(\n",
"            grid_resolution=model_parameters.interp_res,\n",
"            radius=self.surface_radius[ct],\n",
"            neighbors_in_radius=self.surface_neighbors_in_radius[ct],\n",
"        )\n",
"    )\n",
"    self.surface_local_point_conv.append(\n",
"        LocalPointConv(\n",
"            input_features=total_neighbors_in_radius,\n",
"            base_layer=512,\n",
"            output_features=self.surface_neighbors_in_radius[ct],\n",
"        )\n",
"    )\n",
"```\n",
"\n",
"Note that the **BQWarp** layers for the **global geometry representation** are constructed in ```class GeometryRep(nn.Module)```.\n",
"\n",
"**How Does the Multi-Resolution Global Geometry Affect the Local Geometry Representation?**\n",
" - Coarse resolution: At the coarse resolution, you get a broad view of the object. This can give information about the general shape and large-scale features of the geometry (e.g., the overall shape of the object, major boundaries, etc.). When local geometry is extracted from the coarse resolution, the features are relatively less detailed, and it might capture larger, more general features of the object.\n",
" - Fine resolution: At the fine resolution, you get a detailed view of the geometry, capturing small features such as intricate surface details, small holes, or sharp edges. The local geometry representation derived from the fine resolution will be more detailed and capture smaller variations in the geometry near each sampled point.\n",
"\n",
"**Thus, the global multi-resolution geometry allows the local geometry to be learned at different levels of detail, depending on the resolution of the grid that is used to represent the geometry.** \n",
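"\n",
"The stencil construction can be sketched with a brute-force nearest-neighbor search: for each sampled surface point, take its `k` closest mesh points (`k` plays the role of `NUM_SURFACE_NEIGHBORS`) and express them as offsets from the sampled point. Using the k nearest neighbors instead of DoMINO's random neighbor sampling is an assumption made to keep the example short and deterministic.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def build_stencils(sampled, mesh_points, k):\n",
"    # Pairwise distances between the sampled points and all mesh points\n",
"    dist = np.linalg.norm(sampled[:, None, :] - mesh_points[None, :, :], axis=-1)\n",
"    # Indices of the k + 1 closest mesh points; because the sampled points are\n",
"    # themselves mesh points here, the closest one is the sampled point itself\n",
"    idx = np.argsort(dist, axis=1)[:, : k + 1]\n",
"    stencil = mesh_points[idx]                      # (n_sampled, k + 1, 3)\n",
"    # Offsets of every stencil member relative to its sampled point\n",
"    offsets = stencil - sampled[:, None, :]\n",
"    return idx, offsets, np.take_along_axis(dist, idx, axis=1)\n",
"\n",
"rng = np.random.default_rng(1)\n",
"mesh_points = rng.uniform(0.0, 1.0, size=(2000, 3))   # stand-in for surface mesh points\n",
"sampled_idx = rng.choice(len(mesh_points), size=16, replace=False)\n",
"idx, offsets, d = build_stencils(mesh_points[sampled_idx], mesh_points, k=7)\n",
"print(idx.shape, offsets.shape)   # (16, 8) (16, 8, 3)\n",
"```\n",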
"\n",
"\n",
"### **The Prediction Engine: Basis Functions and the Aggregation Network**\n",
"Once DoMINO has processed the local geometry around each point of interest, the next stage involves taking these learned geometric features, along with other relevant information, and predicting the actual flow field values. This is accomplished by an \"Aggregation Network,\" which has a structure inspired by DeepONet.\n",
"\n",
"\n",
"- **A.** **Adding Spatial Context: The Role of Positional Encoding**\n",
"Neural networks, particularly simpler architectures like Multi-Layer Perceptrons (MLPs), often process sets of input features without an inherent understanding of their absolute or relative spatial positions. If a network receives only a list of coordinates, it might not easily distinguish whether a particular point is \"at the front\" or \"at the back\" of an object in a global sense, or how two distant points are spatially related if they are not immediate neighbors. To overcome this, Positional Encoding (PE) techniques are employed. PE injects information about the position of points into the feature vectors that the network processes (*see the right panel in the figure below*). In DoMINO, the positional encoding is calculated between the coordinates of the sampled point and the center of mass of the geometry STL.\n",
"\n",
"- **B.** **Basis Function Neural Network (Latent Vector):**\n",
" - Separately, DoMINO builds dynamic, point-based computational stencils, drawing inspiration from traditional numerical methods, and uses them to learn non-linear basis functions tailored to the geometry. These basis functions, combined with localized geometry encodings, allow the model to predict volume and surface solution fields at specified points (*see the right panel in the figure below*).\\\n",
" What happens here: \n",
" - The input features (coordinates, SDF, normal vectors, etc., and their Fourier features) for each point in the stencil are fed into the Basis Function Neural Network, a fully connected neural network that processes these features.\n",
" - The network then computes a latent vector for each point in the stencil. A latent vector is a compressed mathematical representation that encodes the important information about each point’s geometry and position.\n",
" - Purpose: The latent vector captures the essential characteristics of each point’s geometry in a compact form, which will be used in later steps for predicting the solution at that point.\n",
"\n",
"- **C.** **Concatenating the Latent Vector with the Local Geometry Encoding:**\n",
" - After calculating the latent vector for each point, this vector is concatenated with the local geometry encoding (which includes the previously computed information from the surrounding points and the global geometry) and positional encoding.\n",
" - Why this is done: concatenating these representations allows the network to use both the specific local features of each point and the broader context of the surrounding geometry to make predictions.\n",
"\n",
"- **D.** **Passing Through Additional Neural Network Layers (Solution Prediction):**\n",
" - The combined information (latent vector + positional encoding + local geometry encoding) is passed through another set of fully connected layers (a new neural network).\n",
" - What happens here: These layers process the combined information and predict a solution vector for each point in the stencil. The solution vector could represent various physical quantities such as temperature, pressure, or other simulation results at the sampled point.\n",
" - Purpose: This step produces the predicted solution at each point, based on the local and global geometry.\n",
"\n",
"- **E.** **Aggregation Network:**\n",
"\n",
" The core of the prediction engine in DoMINO is the **Aggregation network**. This is described as **a fully connected neural network with a DeepONet like structure.** This network takes the processed local geometric features and the basis functions (derived from the Basis Function Neural Network and local geometry encodings) and combines them to compute the final solution field values. \n",
" To understand DoMINO's aggregation network, it's helpful to recall the DeepONet architecture. DeepONet is specifically designed for learning operators, mathematical mappings between function spaces. This makes it highly suitable for problems in physics described by PDEs. A DeepONet typically consists of two main sub-networks: \n",
"\n",
" - Branch Network: Processes the input function (e.g., PDE coefficients). \n",
" - Trunk Network: Processes the coordinates where the output function is evaluated. The outputs of these two networks are then combined to produce the prediction.\n",
"\n",
" In DoMINO, this DeepONet structure is adapted as follows:\n",
"\n",
" - **Branch Net**: This part of the aggregation network takes the local geometry representation as input. These are the η features, which encapsulate the detailed local geometric information around point i and its neighbors j.\n",
" - **Trunk Net**: This part processes the basis functions. As described above, these \"basis functions\" are the output of the Basis Function Neural Network after concatenation with local geometry encodings and further MLP layers. They represent a learned, rich description of the query point i's properties, its positional encoding, and its context within its local stencil. \n",
" The DeepONet-like architecture allows DoMINO to effectively learn how different local geometric environments (processed by the branch net) influence the flow solution at various specific locations (whose context, encoded in the basis functions, is processed by the trunk net).\n",
" The aggregation network, with its DeepONet-like structure, \"computes the solution field on the sampled point, i and its neighbors j.\" This means that for a given point of interest i, the network doesn't just predict the solution at i in isolation. Instead, it leverages the local stencil of points (point i itself and its defined neighbors j) to make predictions across this local cloud.\n",
"\n",
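" After the aggregation network produces these individual predictions for point $i$ and its neighbors $j$, the solutions are averaged using an inverse distance weighted (**IDW**) interpolation scheme. In its standard form, IDW estimates the value $u_i$ at point $i$ (position $x_i$) from known values $u_k$ at nearby points $x_k$ as $u_i = \\frac{\\sum_k w_k u_k}{\\sum_k w_k}$, with weights $w_k = 1/d(x_i, x_k)^p$, where $d$ is the distance between the points and $p > 0$ is a power parameter.\n",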
" In DoMINO, the \"*known values*\" $u_k$ are the predictions made by the aggregation network at point $i$ and its neighbors $j$. The IDW scheme then blends these predictions to yield the final solution at point $i$, giving more credence to the predictions made at or very near $i$. \n",
" This IDW step can be seen as a form of learned solution refinement or a consensus mechanism. By predicting solutions across a local stencil and then averaging them with IDW, the model can produce a more robust and spatially consistent output at point $i$, potentially smoothing out minor errors or noise from individual network predictions within that stencil. \n",
" This enforces a degree of local coherence in the predicted flow field, which is a desirable characteristic for physical simulations. IDW is often applied within a defined \"search neighborhood\"; in DoMINO, this neighborhood is implicitly the set of points $j$ (and $i$ itself) for which the aggregation network computes initial solutions.\n",
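"\n",
"The inputs described in items A and B can be sketched as follows, under stated assumptions: the center of mass of the geometry is approximated by the mean of the STL vertices (exact only for uniformly distributed vertices), and each point's offset from that center is expanded into sine/cosine Fourier features. The frequencies and function names are illustrative choices, not DoMINO's exact encoding.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def fourier_features(x, frequencies):\n",
"    # Map each coordinate to sin(2*pi*f*x) and cos(2*pi*f*x) for every frequency f\n",
"    angles = 2.0 * np.pi * x[..., None] * frequencies          # (..., 3, n_freq)\n",
"    feats = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)\n",
"    return feats.reshape(*x.shape[:-1], -1)                     # (..., 3 * 2 * n_freq)\n",
"\n",
"def positional_encoding(points, vertices, frequencies):\n",
"    # Offset of each point from the geometry center, approximated by the vertex mean\n",
"    center = vertices.mean(axis=0)\n",
"    offset = points - center\n",
"    return np.concatenate([offset, fourier_features(offset, frequencies)], axis=-1)\n",
"\n",
"rng = np.random.default_rng(2)\n",
"vertices = rng.normal(size=(1000, 3))   # stand-in for STL vertices\n",
"points = rng.normal(size=(64, 3))       # sampled surface points\n",
"pe = positional_encoding(points, vertices, frequencies=np.array([1.0, 2.0, 4.0]))\n",
"print(pe.shape)   # (64, 21)\n",
"```\n",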
"\n",
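"The branch/trunk combination in item E can be illustrated with a toy DeepONet-style forward pass. Two randomly initialized MLPs stand in for the branch net (input: local geometry encoding) and the trunk net (input: basis-function features), and, as in the original DeepONet, their outputs are combined by a dot product over a shared latent dimension. All sizes, the tanh activation, and the dot-product combination are assumptions for illustration; DoMINO's aggregation network may combine the two inputs differently.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(3)\n",
"\n",
"def mlp(sizes):\n",
"    # Randomly initialized (weight, bias) pairs for a fully connected network\n",
"    return [(rng.normal(scale=1.0 / np.sqrt(m), size=(m, n)), np.zeros(n))\n",
"            for m, n in zip(sizes[:-1], sizes[1:])]\n",
"\n",
"def forward(params, x):\n",
"    # tanh on hidden layers, linear output layer\n",
"    for i, (w, b) in enumerate(params):\n",
"        x = x @ w + b\n",
"        if i < len(params) - 1:\n",
"            x = np.tanh(x)\n",
"    return x\n",
"\n",
"latent, n_outputs = 16, 4   # e.g. pressure plus three wall-shear-stress components\n",
"branch = mlp([32, 64, latent * n_outputs])   # local geometry encoding -> coefficients\n",
"trunk = mlp([24, 64, latent * n_outputs])    # basis-function features -> basis values\n",
"\n",
"n_stencil = 8   # sampled point i plus 7 neighbors j\n",
"geo_encoding = rng.normal(size=(n_stencil, 32))\n",
"basis_features = rng.normal(size=(n_stencil, 24))\n",
"\n",
"b = forward(branch, geo_encoding).reshape(n_stencil, n_outputs, latent)\n",
"t = forward(trunk, basis_features).reshape(n_stencil, n_outputs, latent)\n",
"# DeepONet combination: dot product over the latent dimension for each output field\n",
"u_stencil = np.einsum('sol,sol->so', b, t)\n",
"print(u_stencil.shape)   # (8, 4)\n",
"```\n",
"\n",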
"<div style=\"display: flex; justify-content: center; gap: 10px;\">\n",
" <figure style=\"text-align: center;\">\n",
" <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/DoMINO/images/global_geo_rep.png\" style=\"width: 100%; height: auto;\">\n",
" <figcaption>Computational domain and surface bounding box representation.</figcaption>\n",
" </figure>\n",
" <figure style=\"text-align: center;\">\n",
" <img src=\"https://raw.githubusercontent.com/openhackathons-org/End-to-End-AI-for-Science/main/workspace/python/jupyter_notebook/DoMINO/images/aggregation_net.png\" style=\"width: 55%; height: auto;\">\n",
" <figcaption>The aggregation network is a fully connected neural network with a DeepONet-like structure, where the local geometry representation is the branch net and the basis functions are the trunk net.</figcaption>\n",
" </figure>\n",
"</div>\n"
]
},
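{
"cell_type": "markdown",
"metadata": {},
"source": [
"The IDW averaging step described in the aggregation-network section can be sketched as follows. Given predictions at the sampled point i and its neighbors j, each prediction is weighted by the inverse of its distance to point i. The power `p` and the small offset `eps`, which keeps the weight of point i itself finite, are assumptions; DoMINO's exact weighting may differ.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def idw_average(u_stencil, distances, p=1.0, eps=1e-6):\n",
"    # u_stencil: (n_stencil, n_fields) predictions at point i and its neighbors j\n",
"    # distances: (n_stencil,) distance from each stencil member to point i\n",
"    # eps keeps the weight of point i itself (distance 0) finite\n",
"    weights = 1.0 / (distances + eps) ** p\n",
"    return (weights[:, None] * u_stencil).sum(axis=0) / weights.sum()\n",
"\n",
"# Predictions of two fields (e.g. pressure and one shear component) on a 3-point stencil\n",
"u_stencil = np.array([[1.0, 0.10], [2.0, 0.20], [4.0, 0.40]])\n",
"distances = np.array([0.0, 0.5, 1.0])   # point i itself, then two neighbors\n",
"print(idw_average(u_stencil, distances))\n",
"print(idw_average(u_stencil, distances, eps=0.5))\n",
"```\n",
"\n",
"A larger `eps` or a smaller `p` flattens the weights toward a plain average of the stencil predictions, while a very small `eps` lets the prediction at point i dominate."
]
},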
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## **Training**\n",
"### **Step 1: Define Experiment Parameters and Dependencies**\n",
"\n",
"The first step in training the DoMINO model on the Ahmed body dataset is to set up our experiment environment and define the necessary parameters. This includes specifying paths to our data, configuring training settings, and ensuring all required libraries are available.\n",
"\n",
"Key components we need to set up:\n",
"- Data paths for training and validation sets\n",
"- Model hyperparameters and training configurations\n",
"- Visualization settings for results\n",
"- Required Python libraries for mesh processing and deep learning\n",
"\n",
"#### Loading Required Libraries\n",
"\n",
"Before we proceed with the experiment setup, let's first import all the necessary libraries. These libraries will be used for:\n",
"- Deep learning and numerical computations (torch, numpy)\n",
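"- Progress tracking and visualization (tqdm, matplotlib)\n",
"- Mesh handling and visualization (pyvista)\n",
"- Configuration management (hydra, omegaconf)\n",
"- The DoMINO model, data pipeline, and utilities (physicsnemo)\n"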
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import os\n",
"import re\n",
"import torch\n",
"import torchinfo\n",
"\n",
"import pyvista as pv\n",
"from tqdm import tqdm\n",
"from pathlib import Path\n",
"from types import SimpleNamespace\n",
"import matplotlib.pyplot as plt\n",
"import apex\n",
"import numpy as np\n",
"import hydra\n",
"from hydra.utils import to_absolute_path\n",
"from omegaconf import DictConfig, OmegaConf\n",
"\n",
"from torch.cuda.amp import GradScaler, autocast\n",
"from torch.nn.parallel import DistributedDataParallel\n",
"from torch.utils.data import DataLoader\n",
"from torch.utils.data.distributed import DistributedSampler\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"from physicsnemo.distributed import DistributedManager\n",
"from physicsnemo.launch.utils import load_checkpoint, save_checkpoint\n",
"from physicsnemo.utils.sdf import signed_distance_field\n",
"\n",
"from physicsnemo.datapipes.cae.domino_datapipe import DoMINODataPipe\n",
"from physicsnemo.models.domino.model import DoMINO\n",
"from physicsnemo.utils.domino.utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Experiment Parameters and Variables\n",
"\n",
"In this section, we define all the necessary parameters and variables for our Ahmed body experiment. These parameters control various aspects of the training process, data processing, and model configuration.\n",
"\n",
"These parameters were chosen with the following considerations in mind:\n",
"- The physical dimensions of the Ahmed body\n",
"- The computational requirements of the DoMINO model\n",
"- The resolution needed for accurate flow prediction\n",
"- The available computational resources\n",
"- The specific requirements of the aerodynamic analysis\n",
"\n",
"The most important parameter groups are:\n",
"- `GEOMETRY_REP` contains the hyperparameters for the global geometry representation network.\n",
"- `GEOMETRY_LOCAL` contains the hyperparameters for the local geometry representation.\n",
"- As described in the theoretical section, the point convolution kernel depends on two additional factors: the radius of influence and the number of points included in the kernel. The radii of influence are specified as `volume_radii` and `surface_radii` within both `GEOMETRY_REP` and `GEOMETRY_LOCAL`.\n",
"- The number of points within the kernel is set by `volume_neighbors_in_radius=[128, 128]` and `surface_neighbors_in_radius=[128]` for the local geometry representation. For the global geometry representation, these values are not set explicitly in the `config.yaml` file, so the default value of `10` is used.\n",
"- The **bounding box parameters** (`BOUNDING_BOX` for the volume mesh and `BOUNDING_BOX_SURF` for the surface mesh) define the computational domain, so they must enclose all relevant flow features around the Ahmed body.\n"
]
},
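{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the two kernel factors concrete, the NumPy sketch below implements a brute-force radius query that keeps at most a fixed number of neighbors per query point, the role played by a radius such as `surface_radii` together with a cap such as `surface_neighbors_in_radius`. It is an illustrative assumption, not DoMINO's GPU implementation: the helper name `ball_query`, the closest-first ordering, and the `-1` padding for empty slots are choices made here for clarity.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def ball_query(points, queries, radius, max_neighbors):\n",
"    # Hypothetical helper: for each query, return the indices of up to\n",
"    # max_neighbors points within `radius`, closest first.\n",
"    # Unused slots are filled with -1 (an assumption for this sketch).\n",
"    out = np.full((len(queries), max_neighbors), -1, dtype=np.int64)\n",
"    for i, q in enumerate(queries):\n",
"        dist = np.linalg.norm(points - q, axis=1)\n",
"        inside = np.flatnonzero(dist <= radius)\n",
"        inside = inside[np.argsort(dist[inside])][:max_neighbors]\n",
"        out[i, :len(inside)] = inside\n",
"    return out\n",
"\n",
"rng = np.random.default_rng(0)\n",
"pts = rng.uniform(-1.0, 1.0, size=(1000, 3))\n",
"idx = ball_query(pts, pts[:4], radius=0.3, max_neighbors=16)\n",
"print(idx.shape)  # (4, 16)\n",
"```\n",
"\n",
"The brute-force distance computation costs O(N) per query; production code would typically use a spatial index or a GPU kernel instead."
]
},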
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Directory and Path Configuration\n",
"EXPERIMENT_TAG = 4 # Unique identifier for this experiment run\n",
"PROJECT_NAME = \"ahmed_body_dataset\" # Name of the project\n",
"OUTPUT_DIR = Path(f\"./outputs/{PROJECT_NAME}/{EXPERIMENT_TAG}\") # Directory for experiment outputs\n",
"DATA_DIR = Path(\"/data/physicsnemo_ahmed_body_dataset_vv1/dataset\") # Root directory for dataset\n",
"\n",
"CHECKPOINT_DIR = OUTPUT_DIR / \"models\" # Directory for saving model checkpoints\n",
"SAVE_PATH = DATA_DIR / \"mesh_predictions_surf_final1\" # path to save prediction results\n",
"\n",
"# Ensure directories exist\n",
"\n",
"os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n",
"\n",
"# Physical Variables\n",
"VOLUME_VARS = [\"p\"] # Volume variables to predict (pressure)\n",
"SURFACE_VARS = [\"p\", \"wallShearStress\"] # Surface variables to predict\n",
"MODEL_TYPE = \"surface\" # Type of model (surface-only prediction)\n",
"AIR_DENSITY = 1.205 # Air density in kg/m³\n",
"\n",
"# Training Hyperparameters\n",
"NUM_EPOCHS = 3 # Number of training epochs\n",
"LR = 0.001 # Learning rate\n",
"BATCH_SIZE = 1 # Batch size for training\n",
"GRID_RESOLUTION = [128, 64, 48] # Resolution of the interpolation grid\n",
"SURFACE_POINTS_SAMPLE = 8192\n",
"GEOMETRY_REP=SimpleNamespace(\n",
" geo_conv=SimpleNamespace(base_neurons=32, base_neurons_out=1, volume_radii=[0.1, 0.5], surface_radii=[0.05], hops=1),\n",
" geo_processor=SimpleNamespace(base_filters=8),\n",
" geo_processor_sdf=SimpleNamespace(base_filters=8)\n",
") # Hyperparameters for global geometry representation network\n",
"\n",
"GEOMETRY_LOCAL=SimpleNamespace(volume_neighbors_in_radius=[128, 128], surface_neighbors_in_radius=[128], volume_radii=[0.05, 0.1], surface_radii=[0.05], base_layer=512) # Hyperparameters for local geometry extraction \n",
"\n",
"NUM_SURFACE_NEIGHBORS = 7 # Number of neighbors for surface operations\n",
"NORMALIZATION = \"min_max_scaling\" # Data normalization method\n",
"INTEGRAL_LOSS_SCALING = 0 # Scaling factor for integral loss\n",
"GEOMETRY_ENCODING_TYPE = \"both\" # Geometry encoder type: \"sdf\", \"stl\", or \"both\"\n",
"NUM_SURF_VARS = 4 # Number of surface output channels: 1 for scalar pressure (p) + 3 for the wallShearStress vector\n",
"CHECKPOINT_INTERVAL = 1 # Save checkpoint every N epochs\n",
"\n",
"\n",
"# Dataset Paths\n",
"DATA_PATHS = {\n",
" \"train\": f\"{DATA_DIR}/train_prepared_surface_data/\",\n",
" \"val\": f\"{DATA_DIR}/validation_prepared_surface_data/\",\n",
" \"test\": f\"{DATA_DIR}/test\"\n",
"}\n",
"\n",
"# Model and Scaling Factor Paths\n",
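"MODEL_SAVE_DIR = str(CHECKPOINT_DIR) # Same directory as CHECKPOINT_DIR, derived from it so the two cannot drift apart\n",
"SURF_SAVE_PATH = str(OUTPUT_DIR.parent / \"surface_scaling_factors.npy\") # Surface scaling factors used for normalization\n",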
"\n",
"# Bounding Box Configuration for Volume and Surface Meshes\n",
"BOUNDING_BOX = SimpleNamespace(\n",
" max=[0.5, 0.6, 0.6], # Maximum coordinates for volume mesh\n",
" min=[-2.5, -0.5, -0.5] # Minimum coordinates for volume mesh\n",
")\n",
"BOUNDING_BOX_SURF = SimpleNamespace(\n",
" max=[0.01, 0.6, 0.4], # Maximum coordinates for surface mesh\n",
" min=[-1.6, -0.01, -0.01] # Minimum coordinates for surface mesh\n",
")\n",
"\n",
"# Set cuDNN benchmark mode\n",
"torch.backends.cudnn.benchmark = True\n",
"torch.backends.cudnn.deterministic = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### **Step 2: Train the DoMINO Model**\n",
"\n",
"This step of our workflow trains the DoMINO model on the processed CFD data. It is essential because it:\n",
"- Enables the model to learn complex fluid dynamics patterns\n",
"- Provides a foundation for accurate flow field predictions\n",
"- Allows efficient inference on new geometries\n",
"- Supports distributed training for improved performance\n",
"\n",
"#### Understanding the Training Process\n",
"\n",
"The training process involves several key components:\n",
"1. Setting up distributed training environment\n",
"2. Creating and configuring datasets and dataloaders\n",
"3. Initializing the DoMINO model architecture\n",
"4. Implementing training and validation loops\n",
"5. Managing model checkpoints and metrics\n",
"\n",
"#### Key Components and Libraries\n",
"\n",
"We'll use the following for training:\n",
"\n",
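"- **PyTorch**\n",
"  - `torch.distributed`: For distributed training\n",
"  - `torch.cuda` / `torch.cuda.amp`: For GPU acceleration and mixed-precision training\n",
"  - `apex.optimizers.FusedAdam`: A fused Adam optimizer, used here instead of a `torch.optim` optimizer\n",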
"\n",
"- **Data Management**\n",
" - Custom dataset classes for CFD data\n",
" - Distributed samplers for efficient data loading\n",
"\n",
"#### Important Training Parameters\n",
"\n",
"During the training process, we need to consider:\n",
"- Batch size and learning rate\n",
"- Number of epochs and validation frequency\n",
"- Model architecture parameters\n",
"- Loss function configuration\n",
"- Checkpointing strategy\n",
"\n",
"#### Implementation Overview\n",
"\n",
"The training is implemented through several key components:\n",
"\n",
"1. **Model Creation**\n",
"`create_model()` initializes the DoMINO model with the full configuration and moves it to the target device (usually a CUDA GPU).\n",
"When training is distributed, the model is wrapped in `DistributedDataParallel` so that gradients are synchronized across processes.\n",
" \n",
"```python\n",
"def create_model(device, rank, world_size):\n",
" \"\"\"Create and configure DoMINO model.\"\"\"\n",
" # Initializes model with specified parameters\n",
"```\n",
"2. **Create DoMINO dataset**\n",
"`create_dataset()` constructs and returns a `DoMINODataPipe` for the requested phase (`train` or `val`), fully configured for data loading.\n",
"It combines the phase-specific data path with the global settings (grid resolution, surface variables, normalization, bounding boxes, sampling sizes, etc.).\n",
"\n",
"```python\n",
"def create_dataset(phase):\n",
" \"\"\"\n",
" Create DoMINO dataset for specified phase (train/val).\n",
" \n",
" Args:\n",
" phase (str): Dataset phase ('train' or 'val')\n",
" \n",
" Returns:\n",
" DoMINODataPipe: Configured dataset\n",
" \"\"\"\n",
"```\n",
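"\n",
"Because `NORMALIZATION` is set to `min_max_scaling`, the surface targets are rescaled with per-variable maxima and minima (the factors loaded from `SURF_SAVE_PATH`). The sketch below assumes the target range is [-1, 1] and that row 0 of the factor array holds the maxima and row 1 the minima, consistent with how `test_step` later calls `unnormalize(prediction_surf, surf_factors[0], surf_factors[1])`. The helper names are hypothetical and the factor values are made up.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical helpers: min-max scaling to [-1, 1] and its inverse\n",
"def min_max_normalize(field, max_val, min_val):\n",
"    return 2.0 * (field - min_val) / (max_val - min_val) - 1.0\n",
"\n",
"def min_max_unnormalize(field, max_val, min_val):\n",
"    return (field + 1.0) * (max_val - min_val) / 2.0 + min_val\n",
"\n",
"# Made-up factors for 4 surface channels: p, tau_x, tau_y, tau_z\n",
"factors = np.array([[1.0, 0.02, 0.01, 0.01],       # maxima\n",
"                    [-2.0, -0.02, -0.01, -0.01]])  # minima\n",
"surface_fields = np.array([[-0.5, 0.01, 0.0, -0.005]])\n",
"\n",
"scaled = min_max_normalize(surface_fields, factors[0], factors[1])\n",
"restored = min_max_unnormalize(scaled, factors[0], factors[1])\n",
"print(np.allclose(restored, surface_fields))  # True\n",
"```\n",
"\n",
"A channel whose maximum equals its minimum would cause a division by zero, so such factors need special handling.\n",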
"\n",
"3. **Training Loop**\n",
"`train()` runs the main loop over multiple epochs. It handles the distributed settings, sets up the optimizer and gradient scaler, and calls `run_epoch()` for training and evaluation.\n",
"It first builds the dataloaders and samplers for the train/val datasets, then creates the optimizer (`FusedAdam`) and the gradient scaler.\n",
"\n",
"For each epoch, it:\n",
"- Updates the distributed sampler's epoch (required for correct shuffling across ranks).\n",
"- Calls `run_epoch()` to train and validate.\n",
"- Updates the best validation loss.\n",
" \n",
"```python\n",
"def train(model, device, rank, world_size):\n",
" \"\"\"Orchestrates the training process.\"\"\"\n",
" # Handles training loop, validation, and checkpointing\n",
"```\n",
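"\n",
"The reason `set_epoch` matters: `DistributedSampler` derives its shuffle from a seed plus the current epoch, pads the index list so that every rank receives the same number of samples, and then gives each rank a strided slice. The pure-Python sketch below imitates that logic under simplifying assumptions (`shuffle=True`, `drop_last=False`, padding never longer than the dataset). It uses Python's `random` module instead of `torch.Generator`, so its orderings will not match PyTorch's, and `shard_indices` is a hypothetical name.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"def shard_indices(n, rank, world_size, epoch, seed=0):\n",
"    # Simplified imitation of DistributedSampler (shuffle=True, drop_last=False)\n",
"    order = list(range(n))\n",
"    random.Random(seed + epoch).shuffle(order)   # new permutation per epoch\n",
"    total = math.ceil(n / world_size) * world_size\n",
"    order += order[: total - n]                  # pad by repeating the head\n",
"    return order[rank:total:world_size]          # strided slice for this rank\n",
"\n",
"for epoch in range(2):\n",
"    shards = [shard_indices(10, r, world_size=4, epoch=epoch) for r in range(4)]\n",
"    print(epoch, shards)\n",
"```\n",
"\n",
"Without `set_epoch`, the epoch term stays at 0 and every epoch reuses the same permutation.\n",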
"\n",
"4. **Single Epoch: `run_epoch()`**\n",
"Runs one training epoch, performs forward and backward passes, computes losses, and evaluates on validation data. Supports mixed precision and distributed training.\n",
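"\n",
"Mixed precision relies on `GradScaler`: the loss is multiplied by a scale factor before `backward()` so that small float16 gradients do not underflow, the optimizer step is skipped if the scaled gradients contain inf or NaN, and the scale is adapted over time. The toy class below models only that update policy, using PyTorch's default hyperparameters (`init_scale=65536`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). It is a simplified illustration, not PyTorch's implementation, and `ToyLossScaler` is a hypothetical name.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"class ToyLossScaler:\n",
"    # Simplified dynamic loss-scaling policy (illustration only).\n",
"    # The real scaler also divides gradients by the scale before stepping.\n",
"    def __init__(self, init_scale=65536.0, growth_factor=2.0,\n",
"                 backoff_factor=0.5, growth_interval=2000):\n",
"        self.scale = init_scale\n",
"        self.growth_factor = growth_factor\n",
"        self.backoff_factor = backoff_factor\n",
"        self.growth_interval = growth_interval\n",
"        self._good_steps = 0\n",
"\n",
"    def step(self, scaled_grads):\n",
"        # Returns True if the optimizer step should be applied\n",
"        if any(not math.isfinite(g) for g in scaled_grads):\n",
"            self.scale *= self.backoff_factor   # overflow: shrink scale, skip step\n",
"            self._good_steps = 0\n",
"            return False\n",
"        self._good_steps += 1\n",
"        if self._good_steps == self.growth_interval:\n",
"            self.scale *= self.growth_factor    # stable for a while: grow scale\n",
"            self._good_steps = 0\n",
"        return True\n",
"\n",
"toy = ToyLossScaler(growth_interval=3)\n",
"print(toy.step([float('inf')]), toy.scale)  # False 32768.0\n",
"for _ in range(3):\n",
"    toy.step([0.1, -0.2])\n",
"print(toy.scale)  # 65536.0\n",
"```\n",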
" \n",
"Step-by-step:\n",
"- Training phase:\n",
" - Set model to training mode.\n",
" - Use tqdm progress bar if rank is 0 (main process).\n",
" - For each batch:\n",
" - Move it to the correct device.\n",
" - Run model and compute predictions.\n",
" - Compute masked MSE loss using mse_loss_fn.\n",
" - Apply gradient scaling for mixed-precision training.\n",
" - Step the optimizer and clear gradients.\n",
" - Log training loss (on rank 0).\n",
"- Validation phase:\n",
" - Run inference without gradient tracking.\n",
" - Average validation loss over the val loader.\n",
"\n",
"- Checkpointing:\n",
" - Save best model if current val loss is better.\n",
" - Save periodic checkpoint if epoch meets interval.\n",
" - Return validation loss for tracking best model.\n",
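"\n",
"To see why the mask in `mse_loss_fn` matters, the NumPy mirror below applies the same rule to a tiny batch in which one row carries the padding sentinel `-10` (the default `padded_value`). The values are made up and `masked_mse` is a hypothetical name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def masked_mse(output, target, padded_value=-10.0):\n",
"    # NumPy mirror of mse_loss_fn: ignore entries equal to the padding value\n",
"    mask = np.abs(target - padded_value) > 1e-3\n",
"    return np.sum(((output - target) ** 2) * mask) / np.sum(mask)\n",
"\n",
"target = np.array([[1.0, 2.0], [-10.0, -10.0]])  # second row is padding\n",
"output = np.array([[1.5, 2.0], [0.0, 0.0]])\n",
"print(masked_mse(output, target))  # 0.125\n",
"```\n",
"\n",
"Without the mask, the padded row would dominate: the plain mean of the squared errors for the same arrays is 50.0625.\n",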
"\n",
"Let's proceed with implementing these components and training our model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mse_loss_fn(output, target, padded_value=-10):\n",
" \"\"\"\n",
" Compute masked MSE loss, ignoring padded values.\n",
" \n",
" Args:\n",
" output (torch.Tensor): Model predictions\n",
" target (torch.Tensor): Ground truth values\n",
" padded_value (float): Value used for padding (default: -10)\n",
" \n",
" Returns:\n",
" torch.Tensor: Mean squared error loss\n",
" \"\"\"\n",
" # Move target to same device as output\n",
" target = target.to(output.device)\n",
" # Create mask for non-padded values\n",
" mask = torch.abs(target - padded_value) > 1e-3\n",
" # Compute masked loss\n",
" masked_loss = torch.sum(((output - target) ** 2) * mask) / torch.sum(mask)\n",
" return masked_loss.mean()\n",
"\n",
"def create_dataset(phase):\n",
" \"\"\"\n",
" Create DoMINO dataset for specified phase (train/val).\n",
" \n",
" Args:\n",
" phase (str): Dataset phase ('train' or 'val')\n",
" \n",
" Returns:\n",
" DoMINODataPipe: Configured dataset\n",
" \"\"\"\n",
" return DoMINODataPipe(\n",
" DATA_PATHS[phase],\n",
" phase=phase,\n",
" grid_resolution=GRID_RESOLUTION,\n",
" surface_variables=SURFACE_VARS,\n",
" normalize_coordinates=True,\n",
" sampling=True,\n",
" sample_in_bbox=True,\n",
" volume_points_sample=8192,\n",
" surface_points_sample=SURFACE_POINTS_SAMPLE,\n",
" geom_points_sample=60000,\n",
" positional_encoding=False,\n",
" surface_factors=np.load(SURF_SAVE_PATH),\n",
" scaling_type=NORMALIZATION,\n",
" model_type=MODEL_TYPE,\n",
" bounding_box_dims=BOUNDING_BOX,\n",
" bounding_box_dims_surf=BOUNDING_BOX_SURF,\n",
" num_surface_neighbors=NUM_SURFACE_NEIGHBORS,\n",
" gpu_preprocessing=False\n",
" )\n",
"\n",
"\n",
"\n",
"def create_dataloaders(rank, world_size):\n",
" \"\"\"\n",
" Create train and validation dataloaders with distributed sampling.\n",
" \n",
" Args:\n",
" rank (int): Process rank\n",
" world_size (int): Total number of processes\n",
" \n",
" Returns:\n",
" tuple: (train_loader, val_loader, train_sampler, val_sampler)\n",
" \"\"\"\n",
" # Create datasets\n",
" train_dataset, val_dataset = create_dataset(\"train\"), create_dataset(\"val\")\n",
" \n",
" # Configure distributed samplers if needed\n",
" train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None\n",
" val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None\n",
" \n",
" # Create dataloaders\n",
" return (\n",
" DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler, shuffle=train_sampler is None),\n",
" DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler, shuffle=False),\n",
" train_sampler, val_sampler\n",
" )\n",
"\n",
"def create_model(device, rank, world_size):\n",
" \"\"\"\n",
" Create and configure DoMINO model with distributed training support.\n",
" \n",
" Args:\n",
" device (torch.device): Computation device\n",
" rank (int): Process rank\n",
" world_size (int): Total number of processes\n",
" \n",
" Returns:\n",
" DoMINO: Configured model (wrapped in DistributedDataParallel if world_size > 1)\n",
" \"\"\"\n",
"\n",
" \n",
" # Initialize model with configuration\n",
" model = DoMINO(\n",
" input_features=3,\n",
" output_features_vol=None,\n",
" output_features_surf=NUM_SURF_VARS,\n",
" model_parameters=SimpleNamespace(\n",
" interp_res=GRID_RESOLUTION,\n",
" surface_neighbors=NUM_SURFACE_NEIGHBORS,\n",
" use_surface_normals=True,\n",
" use_surface_area=True,\n",
" encode_parameters=True,\n",
" positional_encoding=False,\n",
" integral_loss_scaling_factor=INTEGRAL_LOSS_SCALING,\n",
" normalization=NORMALIZATION,\n",
" use_sdf_in_basis_func=True,\n",
" geometry_encoding_type= GEOMETRY_ENCODING_TYPE, # geometry encoder type, sdf, stl, both\n",
" geometry_rep=GEOMETRY_REP,\n",
" nn_basis_functions=SimpleNamespace(base_layer=512, fourier_features=False, num_modes=5),\n",
" parameter_model=SimpleNamespace(base_layer=512, scaling_params=[30.0, 1.226], fourier_features=False, num_modes=5),\n",
" position_encoder=SimpleNamespace(base_neurons=512),\n",
" geometry_local=GEOMETRY_LOCAL,\n",
" aggregation_model=SimpleNamespace(base_layer=512),\n",
" model_type=MODEL_TYPE\n",
" ),\n",
" ).to(device)\n",
" \n",
" # Wrap model for distributed training if needed\n",
" if world_size > 1:\n",
" model = DistributedDataParallel(\n",
" model, \n",
" device_ids=[rank], \n",
" output_device=rank, \n",
" find_unused_parameters=True\n",
" )\n",
" \n",
" return model\n",
"\n",
"def run_epoch(train_loader, val_loader, model, optimizer, scaler, device, epoch, best_vloss, rank, world_size):\n",
" \"\"\"\n",
" Run one training epoch with validation.\n",
" \n",
" Args:\n",
" train_loader (DataLoader): Training data loader\n",
" val_loader (DataLoader): Validation data loader\n",
" model (DoMINO): Model to train\n",
" optimizer (torch.optim.Optimizer): Optimizer\n",
" scaler (GradScaler): Gradient scaler for mixed precision\n",
" device (torch.device): Computation device\n",
" epoch (int): Current epoch number\n",
" best_vloss (float): Best validation loss so far\n",
" rank (int): Process rank\n",
" world_size (int): Total number of processes\n",
" \n",
" Returns:\n",
" float: Validation loss for this epoch\n",
" \"\"\"\n",
" # Training phase\n",
" model.train()\n",
" train_loss = 0.0\n",
" pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{NUM_EPOCHS}\") if rank == 0 else train_loader\n",
" \n",
"    for step, batch in enumerate(pbar):\n",
"        # Move batch to device\n",
"        batch = dict_to_device(batch, device)\n",
"        \n",
"        # Forward pass with mixed precision\n",
"        with autocast():\n",
"            _, pred_surf = model(batch)\n",
"            loss = mse_loss_fn(pred_surf, batch[\"surface_fields\"])\n",
"        \n",
"        # Backward pass with gradient scaling\n",
"        scaler.scale(loss).backward()\n",
"        scaler.step(optimizer)\n",
"        scaler.update()\n",
"        optimizer.zero_grad()\n",
"        \n",
"        # Update loss tracking (running mean over the batches seen so far;\n",
"        # tqdm's pbar.n is refreshed lazily, so count steps explicitly)\n",
"        train_loss += loss.item()\n",
"        if rank == 0:\n",
"            pbar.set_postfix({\n",
"                \"train_loss\": f\"{train_loss/(step+1):.5e}\",\n",
"                \"lr\": f\"{optimizer.param_groups[0]['lr']:.2e}\"\n",
"            })\n",
" \n",
" # Compute average training loss\n",
" avg_train_loss = train_loss / len(train_loader)\n",
" \n",
" # Validation phase\n",
" model.eval()\n",
" with torch.no_grad():\n",
" val_loss = sum(\n",
" mse_loss_fn(model(dict_to_device(batch, device))[1], batch[\"surface_fields\"].to(device)).item() \n",
" for batch in val_loader\n",
" ) / len(val_loader)\n",
" \n",
" # Handle distributed training metrics\n",
" if world_size > 1:\n",
" avg_train_loss, val_loss = [torch.tensor(v, device=device) for v in [avg_train_loss, val_loss]]\n",
" torch.distributed.all_reduce(avg_train_loss, op=torch.distributed.ReduceOp.SUM)\n",
" torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.SUM)\n",
" avg_train_loss, val_loss = avg_train_loss.item() / world_size, val_loss.item() / world_size\n",
" \n",
" # Save checkpoints on main process\n",
" if rank == 0:\n",
" if val_loss < best_vloss:\n",
" save_checkpoint(\n",
" os.path.join(MODEL_SAVE_DIR, \"best_model\"), \n",
" models=model,\n",
" optimizer=optimizer,\n",
" scaler=scaler\n",
" )\n",
"\n",
" if (epoch + 1) % CHECKPOINT_INTERVAL == 0:\n",
" save_checkpoint(\n",
" MODEL_SAVE_DIR, \n",
" models=model, \n",
" optimizer=optimizer, \n",
" scaler=scaler, \n",
" epoch=epoch\n",
" )\n",
" \n",
" return val_loss\n",
"\n",
"def train(model, device, rank, world_size):\n",
" \"\"\"\n",
" Function that orchestrates the training process.\n",
" Handles distributed training setup, model creation and training loop.\n",
" \"\"\"\n",
"\n",
"\n",
" # Create output directory on main process\n",
"    if rank == 0:\n",
"        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)\n",
" \n",
" # Set up data\n",
" train_loader, val_loader, train_sampler, val_sampler = create_dataloaders(rank, world_size)\n",
"\n",
"    optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=LR)\n",
" \n",
" # Initialize learning rate scheduler and gradient scaler\n",
" #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 2], gamma=0.5)\n",
" scaler = GradScaler()\n",
" \n",
" # Training loop\n",
" best_vloss = float('inf')\n",
" for epoch in range(NUM_EPOCHS):\n",
" if world_size > 1:\n",
" train_sampler.set_epoch(epoch)\n",
" best_vloss = min(\n",
" best_vloss, \n",
" run_epoch(\n",
" train_loader, val_loader, model, optimizer, \n",
" scaler, device, epoch, best_vloss, rank, world_size\n",
" )\n",
" )\n",
" #scheduler.step()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run the training for a few epochs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize distributed training\n",
"os.environ[\"RANK\"] = \"0\"\n",
"os.environ[\"WORLD_SIZE\"] = \"1\"\n",
"os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
"os.environ[\"MASTER_PORT\"] = \"12355\"\n",
"os.environ[\"LOCAL_RANK\"] = \"0\"\n",
"\n",
"DistributedManager.initialize()\n",
"dist = DistributedManager()\n",
"device=dist.device\n",
"rank=dist.rank\n",
"world_size=dist.world_size\n",
"print(device)\n",
"# Set up model\n",
"model = create_model(device, rank, world_size)\n",
"# Run training\n",
"train(model, device, rank, world_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## **Load Model Checkpoint & Run Inference**\n",
"\n",
"The sixth step in our workflow focuses on evaluating our trained DoMINO model by loading the best checkpoint and running inference on sample cases. \n",
"To run inference, the script needs several key **inputs**. It requires the 3D shape of the object, defined in an STL file, and a corresponding surface mesh provided as a VTP file, which also contains the results of a traditional CFD simulation for comparison. \n",
"Crucially, it needs the pre-trained DoMINO AI model loaded from a checkpoint file. Additionally, basic flow conditions like air speed (`STREAM_VELOCITY`) and density (`AIR_DENSITY`), along with specific scaling factors saved from the training phase (used to convert model outputs to physical values), must be provided.\n",
"\n",
"\n",
"As its main **output**, the script generates new VTP files for each tested geometry. These files include the original surface mesh data but are augmented with new data fields representing the AI model's predictions for aerodynamic quantities such as surface pressure and wall shear stress. Furthermore, the script calculates aerodynamic forces based on these predictions and prints a comparison against forces derived from reference data directly to the console.\n",
"The code snippet below takes geometry files (STL) and corresponding simulation setup data (partially from VTP files and config parameters), preprocesses them, feeds them into the model to predict aerodynamic quantities (like surface pressure and shear stress), and saves these predictions back into VTP files for analysis and visualization.\n",
"\n",
"#### Understanding the Testing Process\n",
"\n",
"The testing process involves several key components:\n",
"1. Loading the best model checkpoint\n",
"2. Preparing test data\n",
"3. Running inference on test cases\n",
"4. Analyzing prediction results\n",
"5. Comparing with ground truth values\n",
"\n",
"#### Key Components and Libraries\n",
"\n",
"We'll use the following libraries for testing:\n",
"\n",
"1. **PyTorch**\n",
" - `torch.load()`: For loading model checkpoints\n",
" - `model.load_state_dict()`: For restoring model weights\n",
" - `torch.no_grad()`: For efficient inference\n",
"\n",
"2. **Custom Testing Functions**\n",
" - `test_step()`: For running inference on test cases\n",
" - Data processing utilities for test data preparation\n",
"\n",
"#### Implementation Overview\n",
"\n",
"The testing is implemented through several key components:\n",
"\n",
"1. **Function: `test_step`**\n",
"Within `test_step`, several operations execute in sequence. First, `torch.no_grad()` disables gradient tracking, which saves memory and computation because gradients are not needed during inference.\n",
"Next, the inputs are extracted from `data_dict`: air density, stream velocity, geometry coordinates, the bounding-box grid (`surf_grid`), and the signed distance field (SDF) evaluated on that grid. The SDF is particularly important because it tells the model where each grid point lies relative to the geometry surface.\n",
"The geometry coordinates are then normalized to [-1, 1] within the bounding box, and a global geometry encoding is generated with `model.geo_rep_surface`, which combines the normalized geometry points, the grid, and the SDF into a representation of the overall shape.\n",
"Surface-specific data, including mesh centers, normals, areas, and neighbor information (found during preprocessing, e.g. with a KD-tree), are then extracted.\n",
"To focus on the points where predictions are needed, `model.geo_encoding_local(..., mode=\"surface\")` extracts local geometric features from the global encoding at each surface point.\n",
"Positional information is added with `model.position_encoder`, which encodes the position of each surface point relative to the surface center of mass (`pos_surface_center_of_mass`).\n",
"The core prediction is made by `model.calculate_solution_with_neighbors`, which combines the local geometry encoding, positional encoding, surface point details, neighbor information, and flow conditions to estimate the target surface fields (pressure and wall shear stress). Because the model output is normalized, a final un-normalization step converts the predictions back into physical units using `surf_factors`, the stream velocity, and the air density. The function returns `(prediction_vol, prediction_surf)`, where `prediction_vol` is always `None` because this code path computes surface predictions only.\n",
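"\n",
"The last step of `test_step` multiplies the un-normalized output by `stream_velocity ** 2 * air_density`, so the network effectively predicts fields non-dimensionalized by ρU² (without the factor of ½ found in a conventional pressure coefficient). The sketch below shows that conversion in both directions. `AIR_DENSITY = 1.205` comes from the parameter cell; the 40 m/s stream velocity and the coefficient values are hypothetical, as are the helper names.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rho = 1.205    # kg/m^3, AIR_DENSITY from the parameter cell\n",
"u_inf = 40.0   # m/s, hypothetical stream velocity for illustration\n",
"\n",
"def to_physical(coeff, velocity, density):\n",
"    # Undo the rho * U^2 non-dimensionalization applied in test_step\n",
"    return coeff * velocity ** 2 * density\n",
"\n",
"def to_coefficient(value, velocity, density):\n",
"    # Inverse mapping: physical value (Pa) back to the dimensionless form\n",
"    return value / (velocity ** 2 * density)\n",
"\n",
"p_coeff = np.array([-0.5, 0.2])   # made-up dimensionless pressures\n",
"p_pa = to_physical(p_coeff, u_inf, rho)\n",
"print(p_pa)   # pressures in Pa\n",
"```\n",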
"\n",
"\n",
"\n",
"```python\n",
"def test_step(model, data_dict, surf_factors, device):\n",
" \"\"\"\n",
" Executes the core inference logic for a single test case using the trained DoMINO model.\n",
" \n",
" Args:\n",
" model (DoMINO): The trained model\n",
" data_dict: A dictionary containing all necessary input data for this specific test case (geometry, mesh points, flow conditions, etc.), already preprocessed and formatted.\n",
" surf_factors: Scaling factors used during training to normalize the target surface data. Needed here to un-normalize the model's predictions back to physical values.\n",
" device: The computational device (CPU or GPU) to run the calculations on.\n",
" \n",
" Returns:\n",
" tuple: (prediction_vol, prediction_surf) - Model predictions for volume and surface\n",
" \"\"\"\n",
"```\n",
"\n",
"\n",
"2. **Function: `test`**\n",
"```python\n",
"def test(filepath, dirname, CKPT_NUMBER: int = 0):\n",
"    \"\"\"\n",
"    Run the full inference pipeline for one test case and save the predictions.\n",
"    \n",
"    Args:\n",
"        filepath (str): Path to the VTP file of the test case\n",
"        dirname (str): Name of the test case, used to extract its numeric tag\n",
"    \n",
"    Returns:\n",
"        None: predictions are written to a VTP file and forces are printed\n",
"    \"\"\"\n",
"```\n",
"\n",
"On the other hand, the test function is a higher-level function that organizes and controls the overall testing process. It begins by checking if surface scaling factors have been pre-computed and stored in a .npy file. If the file exists, it loads these factors; if not, it defaults to None. \n",
"\n",
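"**The function then loads the most recently written checkpoint from the `best_model` directory (files named `DoMINO.0.<index>.pt`) and restores its state into the model.**\n",
"```python\n",
"    # Load the most recent best-model checkpoint\n",
"    checkpoint_path = max(\n",
"        (CHECKPOINT_DIR / \"best_model\").glob(\"DoMINO.0.*.pt\"),\n",
"        key=lambda p: p.stat().st_mtime,\n",
"    )\n",
"    best_checkpoint = torch.load(checkpoint_path)\n",
"    model.load_state_dict(best_checkpoint)  # Load the model state\n",
"    print(f\"Loaded checkpoint: {checkpoint_path.name}\")\n",
"```\n",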
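"\n",
"Because the `best_model` directory can hold several `DoMINO.0.<index>.pt` files, the `test` function in the code cell below selects the most recently modified one rather than a hard-coded name. The standard-library sketch below reproduces that rule in a temporary directory with made-up file names and timestamps; `latest_checkpoint` is a hypothetical name. It also shows that the newest file is not necessarily the one with the highest index, so sorting by the numeric index would be a different, timestamp-independent policy.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"def latest_checkpoint(directory, pattern='DoMINO.0.*.pt'):\n",
"    # Same selection rule as in test(): newest file by modification time\n",
"    return max(Path(directory).glob(pattern), key=lambda p: p.stat().st_mtime)\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
"    for i, name in enumerate(['DoMINO.0.0.pt', 'DoMINO.0.2.pt', 'DoMINO.0.1.pt']):\n",
"        f = Path(tmp) / name\n",
"        f.touch()\n",
"        os.utime(f, (1_000 + i, 1_000 + i))  # made-up, increasing mtimes\n",
"    print(latest_checkpoint(tmp).name)  # DoMINO.0.1.pt\n",
"```\n",
"\n",
"`max()` raises `ValueError` if no file matches the pattern, which signals that training has not written a checkpoint yet.\n",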
"\n",
"After the model is loaded, it creates a directory for saving predictions if it doesn't already exist. The dirname parameter is used to extract a tag, which helps identify the current test case.\n",
"\n",
"Next, the function loads the necessary input files. It reads the STL file containing the 3D geometry and extracts data such as vertices, faces, and cell areas. It computes the bounding box dimensions and the surface’s center of mass. It then prepares a grid (`surf_grid`) and evaluates the signed distance field (SDF) of the geometry on this grid, which encodes how far each grid point is from the surface and on which side it lies. Finally, it reads the VTP file, which holds the reference surface data such as pressure and wall shear stress.\n",
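"\n",
"The `surf_grid` described above is a regular lattice spanning the surface bounding box at `GRID_RESOLUTION` points per axis. The NumPy sketch below builds such a lattice from the values in `BOUNDING_BOX_SURF` and `GRID_RESOLUTION`; the `(nx, ny, nz, 3)` layout, the inclusion of both endpoints, and `indexing='ij'` are assumptions and may differ from the exact tensor the datapipe produces. `make_grid` is a hypothetical name.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"GRID_RESOLUTION = [128, 64, 48]\n",
"BBOX_MIN = np.array([-1.6, -0.01, -0.01])  # BOUNDING_BOX_SURF.min\n",
"BBOX_MAX = np.array([0.01, 0.6, 0.4])      # BOUNDING_BOX_SURF.max\n",
"\n",
"def make_grid(bbox_min, bbox_max, resolution):\n",
"    # One evenly spaced axis per dimension, both endpoints included\n",
"    axes = [np.linspace(lo, hi, n) for lo, hi, n in zip(bbox_min, bbox_max, resolution)]\n",
"    x, y, z = np.meshgrid(*axes, indexing='ij')\n",
"    return np.stack([x, y, z], axis=-1)     # shape (nx, ny, nz, 3)\n",
"\n",
"surf_grid = make_grid(BBOX_MIN, BBOX_MAX, GRID_RESOLUTION)\n",
"print(surf_grid.shape)  # (128, 64, 48, 3)\n",
"```\n",
"\n",
"The SDF is then evaluated at each of these 128 × 64 × 48 = 393,216 points.\n",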
"\n",
"The surface fields are then prepared by interpolating the surface mesh data and its corresponding attributes. These fields are normalized to fit within the bounding box dimensions. The data dictionary is assembled, containing all the relevant inputs needed for the model’s prediction. This dictionary includes things like normalized surface coordinates, surface areas, and field values such as stream velocity and air density. The dictionary is converted to PyTorch tensors, making it compatible with the model.\n",
"\n",
"The `test_step` function is then called with this prepared data to compute the model's predictions. The function then compares the surface forces computed from the predicted pressure and wall shear stress with those computed from the reference surface fields, and prints both. The predicted surface fields are converted to VTK format and saved to a VTP file. Together, these steps form a complete pipeline for testing a trained model on surface data, generating predictions, and saving them for further analysis.\n",
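"\n",
"The force comparison reduces to an area-weighted sum over mesh cells. Assuming outward-pointing unit normals, the force the fluid exerts on the body is `F = sum_i (-p_i * n_i + tau_i) * A_i`; if the mesh normals point into the body, the sign of the pressure term flips. The sketch below implements that formula for two made-up cells with the flow along +x. The function name and values are hypothetical, and the notebook's own force calculation may use a different sign convention.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def surface_force(p, tau, normals, areas):\n",
"    # Force exerted by the fluid on the body, assuming outward unit normals:\n",
"    # pressure pushes against the wall (-p * n), shear acts along it (+tau).\n",
"    return np.sum((-p[:, None] * normals + tau) * areas[:, None], axis=0)\n",
"\n",
"# Two made-up cells: a front face (normal -x) and a rear face (normal +x)\n",
"p = np.array([100.0, -50.0])                          # Pa\n",
"tau = np.array([[2.0, 0.0, 0.0], [1.0, 0.0, 0.0]])    # Pa\n",
"normals = np.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n",
"areas = np.array([0.01, 0.01])                        # m^2\n",
"print(surface_force(p, tau, normals, areas))  # drag points along +x\n",
"```\n",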
"\n",
"\n",
"Let's proceed with loading our trained model and running the tests:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_step(model, data_dict, surf_factors, device):\n",
" \"\"\"\n",
" Run a single test step on the model.\n",
" \n",
" Args:\n",
" model (DoMINO): The trained model\n",
"        data_dict (dict): Dictionary containing test data\n",
"        surf_factors (np.ndarray): Surface scaling factors used to un-normalize predictions\n",
"        device (torch.device): Device to run inference on\n",
" \n",
" Returns:\n",
" tuple: (prediction_vol, prediction_surf) - Model predictions for volume and surface\n",
" \"\"\"\n",
" \n",
" avg_tloss_vol = 0.0 # Placeholder for average volume loss (not currently used)\n",
" avg_tloss_surf = 0.0 # Placeholder for average surface loss (not currently used)\n",
"\n",
" with torch.no_grad(): # Disable gradient computation to save memory and computation during inference\n",
" # Move input data to the specified device (CPU or GPU)\n",
" data_dict = dict_to_device(data_dict, device)\n",
"\n",
" # Extract non-dimensionalization factors (important for scaling the inputs)\n",
" air_density = data_dict[\"air_density\"]\n",
" stream_velocity = data_dict[\"stream_velocity\"]\n",
" length_scale = data_dict[\"length_scale\"]\n",
"\n",
" # Extract geometry coordinates (nodes of the surface)\n",
" geo_centers = data_dict[\"geometry_coordinates\"]\n",
"\n",
" # Extract bounding box grid and signed distance function (SDF) grid for the surface\n",
" s_grid = data_dict[\"surf_grid\"]\n",
" sdf_surf_grid = data_dict[\"sdf_surf_grid\"]\n",
"\n",
" # Extract scaling factors for surface (used for un-normalization)\n",
" surf_max = data_dict[\"surface_min_max\"][:, 1]\n",
" surf_min = data_dict[\"surface_min_max\"][:, 0]\n",
"\n",
" # Normalize geometry coordinates to fit within a bounding box [-1, 1]\n",
" geo_centers_surf = (\n",
" 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1\n",
" )\n",
"\n",
" # Generate geometric representation of the surface\n",
" encoding_g_surf = model.geo_rep_surface(\n",
" geo_centers_surf, s_grid, sdf_surf_grid\n",
" )\n",
"\n",
" prediction_vol = None # Volume prediction is not computed in this function\n",
"\n",
" # Extract information about the surface: mesh centers, normals, areas, and neighbors\n",
" surface_mesh_centers = data_dict[\"surface_mesh_centers\"]\n",
" surface_normals = data_dict[\"surface_normals\"]\n",
" surface_areas = data_dict[\"surface_areas\"]\n",
"\n",
" surface_mesh_neighbors = data_dict[\"surface_mesh_neighbors\"]\n",
" surface_neighbors_normals = data_dict[\"surface_neighbors_normals\"]\n",
" surface_neighbors_areas = data_dict[\"surface_neighbors_areas\"]\n",
"\n",
" surface_areas = torch.unsqueeze(surface_areas, -1) # Add extra dimension\n",
" surface_neighbors_areas = torch.unsqueeze(surface_neighbors_areas, -1) # Add extra dimension\n",
" pos_surface_center_of_mass = data_dict[\"pos_surface_center_of_mass\"]\n",
" num_points = surface_mesh_centers.shape[1] # Number of surface points\n",
" \n",
" # Extract target surface fields (for comparison later)\n",
" target_surf = data_dict[\"surface_fields\"]\n",
" prediction_surf = np.zeros_like(target_surf.cpu().numpy()) # Initialize prediction array\n",
"\n",
" start_time = time.time() # Record the start time for performance measurement\n",
"\n",
" # Generate local geometric encoding for each surface point\n",
" geo_encoding_local = model.geo_encoding_local(\n",
" 0.5 * encoding_g_surf, surface_mesh_centers, s_grid, mode=\"surface\"\n",
" )\n",
"\n",
"\n",
" # Position encoding based on the center of mass of the surface\n",
" pos_encoding = pos_surface_center_of_mass\n",
" pos_encoding = model.position_encoder(pos_encoding, eval_mode=\"surface\")\n",
"\n",
" # Perform the model prediction using neighbors and other surface data\n",
" tpredictions = (\n",
" model.calculate_solution_with_neighbors(\n",
" surface_mesh_centers,\n",
" geo_encoding_local,\n",
" pos_encoding,\n",
" surface_mesh_neighbors,\n",
" surface_normals,\n",
" surface_neighbors_normals,\n",
" surface_areas,\n",
" surface_neighbors_areas,\n",
" stream_velocity,\n",
" air_density,\n",
" )\n",
" )\n",
"\n",
" # Convert model predictions to numpy arrays for further processing\n",
" prediction_surf = tpredictions.cpu().numpy()\n",
"\n",
" # Unnormalize the surface predictions and scale them using physical quantities\n",
" prediction_surf = (\n",
" unnormalize(prediction_surf, surf_factors[0], surf_factors[1])\n",
" * stream_velocity[0, 0].cpu().numpy() ** 2.0\n",
" * air_density[0, 0].cpu().numpy()\n",
" )\n",
"\n",
" return prediction_vol, prediction_surf # Return volume and surface predictions\n",
"\n",
"def test(filepath, dirname, CKPT_NUMBER: int=0):\n",
" \"\"\"\n",
" High-level function to manage the testing pipeline, including data preparation, model loading, and prediction saving.\n",
" \n",
" Args:\n",
"        filepath (str): Path to the VTP file of the test case\n",
" dirname (str): Directory name for the test case\n",
" \n",
" Returns:\n",
" None\n",
" \"\"\"\n",
" # Define names of surface variables to be predicted\n",
" surface_variable_names = SURFACE_VARS\n",
" \n",
" # Check if surface scaling factors are available\n",
" surf_save_path = os.path.join(\n",
" \"outputs\", PROJECT_NAME , \"surface_scaling_factors.npy\"\n",
" )\n",
" if os.path.exists(surf_save_path):\n",
" surf_factors = np.load(surf_save_path) # Load scaling factors if available\n",
" else:\n",
" surf_factors = None # If not available, set to None\n",
" \n",
"    # Load the most recently modified checkpoint from the best_model directory\n",
" checkpoint_path = max(\n",
" (CHECKPOINT_DIR / \"best_model\").glob(\"DoMINO.0.*.pt\"),\n",
" key=lambda p: p.stat().st_mtime,\n",
" )\n",
" print(f\"Loading checkpoint: {checkpoint_path.name}\") # Print only the checkpoint name\n",
" best_checkpoint = torch.load(checkpoint_path)\n",
" model.load_state_dict(best_checkpoint) # Load the model state\n",
" \n",
" # Set the path to save predictions\n",
" pred_save_path = SAVE_PATH\n",
" create_directory(pred_save_path) # Create the output directory if it doesn't exist\n",
" \n",
" # Extract test case identifier from the directory name\n",
" tag = int(re.findall(r\"(\\w+?)(\\d+)\", dirname)[0][1])\n",
" vtp_path = filepath # Path to the VTP file with surface data\n",
" \n",
" # Prepare the path to save predicted results\n",
" vtp_pred_save_path = os.path.join(\n",
" pred_save_path, f\"boundary_{tag}_predicted.vtp\"\n",
" )\n",
" \n",
" # Load the STL file for the geometry\n",
" path_stl = Path(filepath)\n",
" stl_path = path_stl.parent.parent.joinpath(\"test_stl_files\", path_stl.stem + \".stl\")\n",
"    print(f\"STL path: {stl_path}\")\n",
"    print(f\"VTP path: {filepath}\")\n",
"\n",
"    # Read the stream velocity for this case from its info file\n",
"    info_path = path_stl.parent.parent.joinpath(\"test_info\", path_stl.stem + \"_info.txt\")\n",
"    print(f\"Info path: {info_path}\")\n",
"\n",
"    velocity = None\n",
"    with open(info_path, \"r\") as file:\n",
"        for line in file:\n",
"            if \"Velocity\" in line:\n",
"                velocity = float(line.split(\":\")[1].strip())\n",
"                print(f\"Velocity: {velocity}\")\n",
"    if velocity is None:\n",
"        raise ValueError(f\"No 'Velocity' entry found in {info_path}\")\n",
"\n",
"    STREAM_VELOCITY = velocity\n",
" \n",
" # Read and process the STL file\n",
" reader = pv.get_reader(stl_path)\n",
" mesh_stl = reader.read()\n",
" stl_vertices = mesh_stl.points\n",
" stl_faces = np.array(mesh_stl.faces).reshape((-1, 4))[:, 1:] # Extract triangular faces\n",
" mesh_indices_flattened = stl_faces.flatten()\n",
" length_scale = np.amax(np.amax(stl_vertices, 0) - np.amin(stl_vertices, 0)) # Compute scale of the geometry\n",
" stl_sizes = mesh_stl.compute_cell_sizes(length=False, area=True, volume=False)\n",
" stl_sizes = np.array(stl_sizes.cell_data[\"Area\"], dtype=np.float32)\n",
" stl_centers = np.array(mesh_stl.cell_centers().points, dtype=np.float32)\n",
" \n",
" # Calculate the center of mass of the surface\n",
" center_of_mass = calculate_center_of_mass(stl_centers, stl_sizes)\n",
" \n",
" # Extract bounding box dimensions for the surface\n",
" bounding_box_dims_surf = []\n",
" bounding_box_dims_surf.append(np.asarray(BOUNDING_BOX_SURF.max))\n",
" bounding_box_dims_surf.append(np.asarray(BOUNDING_BOX_SURF.min))\n",
" s_max = np.float32(bounding_box_dims_surf[0])\n",
" s_min = np.float32(bounding_box_dims_surf[1])\n",
" \n",
" # Create a 3D grid for the surface\n",
" nx, ny, nz = GRID_RESOLUTION\n",
" surf_grid = create_grid(s_max, s_min, [nx, ny, nz])\n",
" surf_grid_reshaped = surf_grid.reshape(nx * ny * nz, 3)\n",
" \n",
" # Compute the Signed Distance Field (SDF) on the surface grid\n",
" sdf_surf_grid = (\n",
" signed_distance_field(\n",
" stl_vertices,\n",
" mesh_indices_flattened,\n",
" surf_grid_reshaped,\n",
" use_sign_winding_number=True,\n",
" )\n",
" .reshape(nx, ny, nz)\n",
" )\n",
" surf_grid = np.float32(surf_grid)\n",
" sdf_surf_grid = np.float32(sdf_surf_grid)\n",
" surf_grid_max_min = np.float32(np.asarray([s_min, s_max]))\n",
" \n",
" # Read the VTP file containing surface data\n",
" reader = vtk.vtkXMLPolyDataReader()\n",
" reader.SetFileName(vtp_path)\n",
" reader.Update()\n",
" polydata_surf = reader.GetOutput()\n",
" celldata_all = get_node_to_elem(polydata_surf)\n",
" celldata = celldata_all.GetCellData()\n",
" surface_fields = get_fields(celldata, surface_variable_names)\n",
" surface_fields = np.concatenate(surface_fields, axis=-1)\n",
" mesh = pv.PolyData(polydata_surf)\n",
" \n",
" # Extract surface mesh coordinates, neighbors, and normals\n",
" surface_coordinates = np.array(mesh.cell_centers().points, dtype=np.float32)\n",
" interp_func = KDTree(surface_coordinates)\n",
" dd, ii = interp_func.query(surface_coordinates, k=NUM_SURFACE_NEIGHBORS)\n",
" surface_neighbors = surface_coordinates[ii]\n",
" surface_neighbors = surface_neighbors[:, 1:]\n",
" surface_normals = np.array(mesh.cell_normals, dtype=np.float32)\n",
" surface_sizes = mesh.compute_cell_sizes(length=False, area=True, volume=False)\n",
" surface_sizes = np.array(surface_sizes.cell_data[\"Area\"], dtype=np.float32)\n",
" \n",
" # Normalize the surface normals and neighbors\n",
" surface_normals = (\n",
" surface_normals / np.linalg.norm(surface_normals, axis=1)[:, np.newaxis]\n",
" )\n",
" surface_neighbors_normals = surface_normals[ii]\n",
" surface_neighbors_normals = surface_neighbors_normals[:, 1:]\n",
" surface_neighbors_sizes = surface_sizes[ii]\n",
" surface_neighbors_sizes = surface_neighbors_sizes[:, 1:]\n",
" \n",
" # Calculate the grid resolution and normalize the surface data\n",
" dx, dy, dz = (\n",
" (s_max[0] - s_min[0]) / nx,\n",
" (s_max[1] - s_min[1]) / ny,\n",
" (s_max[2] - s_min[2]) / nz,\n",
" )\n",
" pos_surface_center_of_mass = surface_coordinates - center_of_mass\n",
" surface_coordinates = normalize(surface_coordinates, s_max, s_min)\n",
" surface_neighbors = normalize(surface_neighbors, s_max, s_min)\n",
" surf_grid = normalize(surf_grid, s_max, s_min)\n",
" \n",
" # Prepare the data dictionary for model input\n",
" geom_centers = np.float32(stl_vertices)\n",
" data_dict = {\n",
" \"pos_surface_center_of_mass\": np.float32(pos_surface_center_of_mass),\n",
" \"geometry_coordinates\": np.float32(geom_centers),\n",
" \"surf_grid\": np.float32(surf_grid),\n",
" \"sdf_surf_grid\": np.float32(sdf_surf_grid),\n",
" \"surface_mesh_centers\": np.float32(surface_coordinates),\n",
" \"surface_mesh_neighbors\": np.float32(surface_neighbors),\n",
" \"surface_normals\": np.float32(surface_normals),\n",
" \"surface_neighbors_normals\": np.float32(surface_neighbors_normals),\n",
" \"surface_areas\": np.float32(surface_sizes),\n",
" \"surface_neighbors_areas\": np.float32(surface_neighbors_sizes),\n",
" \"surface_fields\": np.float32(surface_fields),\n",
" \"surface_min_max\": np.float32(surf_grid_max_min),\n",
" \"length_scale\": np.array(length_scale, dtype=np.float32),\n",
" \"stream_velocity\": np.expand_dims(\n",
" np.array(STREAM_VELOCITY, dtype=np.float32), axis=-1\n",
" ),\n",
" \"air_density\": np.expand_dims(\n",
" np.array(AIR_DENSITY, dtype=np.float32), axis=-1\n",
" ),\n",
" }\n",
" \n",
" # Convert data dictionary to PyTorch tensors\n",
" data_dict = {\n",
" key: torch.from_numpy(np.expand_dims(np.float32(value), 0))\n",
" for key, value in data_dict.items()\n",
" }\n",
" \n",
" # Perform a test step to get the predictions\n",
" prediction_vol, prediction_surf = test_step(\n",
" model, data_dict, surf_factors, device\n",
" )\n",
" \n",
" # Process the predicted and true surface values to compute forces\n",
" surface_sizes = np.expand_dims(surface_sizes, -1)\n",
" pres_x_pred = np.sum(\n",
" prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0]\n",
" )\n",
" shear_x_pred = np.sum(prediction_surf[0, :, 1] * surface_sizes[:, 0])\n",
" pres_x_true = np.sum(\n",
" surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0]\n",
" )\n",
" shear_x_true = np.sum(surface_fields[:, 1] * surface_sizes[:, 0])\n",
" force_x_pred = np.sum(\n",
" prediction_surf[0, :, 0] * surface_normals[:, 0] * surface_sizes[:, 0]\n",
" - prediction_surf[0, :, 1] * surface_sizes[:, 0]\n",
" )\n",
" force_x_true = np.sum(\n",
" surface_fields[:, 0] * surface_normals[:, 0] * surface_sizes[:, 0]\n",
" - surface_fields[:, 1] * surface_sizes[:, 0]\n",
" )\n",
" \n",
" # Print the computed forces for comparison\n",
" print(dirname, force_x_pred, force_x_true)\n",
" \n",
" # Convert predictions to VTK format and save the results\n",
" surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 0:1])\n",
" surfParam_vtk.SetName(f\"{surface_variable_names[0]}Pred\")\n",
" celldata_all.GetCellData().AddArray(surfParam_vtk)\n",
" surfParam_vtk = numpy_support.numpy_to_vtk(prediction_surf[0, :, 1:])\n",
" surfParam_vtk.SetName(f\"{surface_variable_names[1]}Pred\")\n",
" celldata_all.GetCellData().AddArray(surfParam_vtk)\n",
" write_to_vtp(celldata_all, vtp_pred_save_path) # Save to VTP file\n",
" \n",
" return # End of the test function"
]
},
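{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `test()`, each surface cell's neighbor stencil comes from a k-nearest-neighbor query (`interp_func.query(surface_coordinates, k=NUM_SURFACE_NEIGHBORS)`), and column 0 of the result is then dropped. The next cell is a minimal, self-contained sketch of that step. It assumes the notebook's `KDTree` is SciPy's `scipy.spatial.KDTree`, and the five 2D points and `k = 3` are hypothetical values chosen only for illustration. The query points are the tree's own distinct points, so each point's nearest neighbor is itself at distance 0. That is why column 0 is discarded, leaving `k - 1` true neighbors per point."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.spatial import KDTree\n",
"\n",
"# Hypothetical, distinct cell centers standing in for surface_coordinates\n",
"points = np.array(\n",
"    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [3.0, 3.0]],\n",
"    dtype=np.float32,\n",
")\n",
"k = 3  # stands in for NUM_SURFACE_NEIGHBORS\n",
"\n",
"tree = KDTree(points)\n",
"dd, ii = tree.query(points, k=k)  # both have shape (num_points, k)\n",
"\n",
"# Column 0 is each point itself (distance 0), so drop it as test() does\n",
"neighbors = points[ii][:, 1:]  # shape (num_points, k - 1, 2)\n",
"neighbor_dist = dd[:, 1:]\n",
"\n",
"print(ii[:, 0])  # each point's own index\n",
"print(neighbors.shape)"
]
},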
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"input_path = DATA_PATHS[\"test\"]\n",
"\n",
"CHECKPOINT_NUMBER = 1\n",
"\n",
"dirnames = get_filenames(input_path)\n",
"\n",
"for dirname in dirnames:\n",
"    print(f\"Processing file {dirname}\")\n",
"    filepath = os.path.join(input_path, dirname)\n",
"    test(filepath, dirname, CKPT_NUMBER=CHECKPOINT_NUMBER)\n",
"\n",
"folder = Path(SAVE_PATH)\n",
"predicted_files = list(folder.glob(\"*.vtp\"))"
]
},
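{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each case, the loop above prints the case name, the predicted x-force and the CFD x-force. `test()` computes both forces as a discrete surface integral over the mesh cells, `F_x = sum(p * n_x * A - tau_x * A)`. Here `p` is the surface pressure, `n_x` is the x-component of the unit cell normal, `tau_x` is the x-component of the wall shear stress and `A` is the cell area. The next cell sketches the same sum as a standalone function. `x_force` is a hypothetical helper, and its sign convention is copied from `test()` rather than derived from the dataset's normal orientation. As a sanity check, the sketch uses a unit cube with one cell per face and outward normals. A uniform pressure cancels on opposite faces, so only pressure differences and shear contribute to the net force."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def x_force(pressure, wall_shear_x, normals, areas):\n",
"    # Same sum as test(): pressure projected on x, minus the x-shear term\n",
"    # pressure, wall_shear_x, areas: shape (num_cells,); normals: (num_cells, 3), unit length\n",
"    pressure_part = np.sum(pressure * normals[:, 0] * areas)\n",
"    shear_part = np.sum(wall_shear_x * areas)\n",
"    return pressure_part - shear_part\n",
"\n",
"\n",
"# Unit cube: one cell per face, outward unit normals, each face has area 1\n",
"normals = np.array(\n",
"    [[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]],\n",
"    dtype=np.float64,\n",
")\n",
"areas = np.ones(6)\n",
"no_shear = np.zeros(6)\n",
"\n",
"# Uniform pressure on a closed surface: contributions from opposite faces cancel\n",
"print(x_force(np.full(6, 101325.0), no_shear, normals, areas))\n",
"\n",
"# Higher pressure on the -x face (index 1) leaves a net pressure contribution\n",
"pressure = np.array([1.0, 2.0, 1.0, 1.0, 1.0, 1.0])\n",
"print(x_force(pressure, no_shear, normals, areas))\n",
"\n",
"# An x-shear of 0.5 on each of the four side faces contributes -4 * 0.5 = -2\n",
"side_shear = np.array([0.0, 0.0, 0.5, 0.5, 0.5, 0.5])\n",
"print(x_force(np.zeros(6), side_shear, normals, areas))"
]
},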
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## **Visualizing the predicted results**\n",
"You can visualize the predicted surface pressure with either PyVista or ParaView. The following cell uses `pyvista` to display both the ground-truth and the predicted pressure, which are stored in the `.vtp` files written to the `SAVE_PATH` directory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pyvista as pv\n",
"\n",
"# Start virtual framebuffer for off-screen rendering (useful in Jupyter/containers)\n",
"pv.start_xvfb()\n",
"\n",
"# Read the VTP mesh\n",
"mesh = pv.read(f\"{DATA_DIR}/mesh_predictions_surf_final1/boundary_119_predicted.vtp\")\n",
"print(\"Available cell data keys:\", mesh.cell_data.keys())\n",
"\n",
"# Create a Plotter with 2 vertical subplots\n",
"plotter = pv.Plotter(shape=(2, 1), window_size=[1600, 800], off_screen=True)\n",
"\n",
"# Plot 'p' (ground-truth pressure carried over from the input VTP file)\n",
"plotter.subplot(0, 0)\n",
"plotter.add_text(\"Pressure (p)\", font_size=12)\n",
"plotter.add_mesh(mesh, scalars=\"p\", show_edges=False)\n",
"\n",
"# Plot 'pPred' (predicted pressure)\n",
"plotter.subplot(1, 0)\n",
"plotter.add_text(\"Predicted Pressure (pPred)\", font_size=12)\n",
"plotter.add_mesh(mesh, scalars=\"pPred\", show_edges=False)\n",
"\n",
"# Show both subplots\n",
"plotter.show(jupyter_backend='static')"
]
},
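{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two plots give a qualitative comparison. To get a single number per case, the next cell sketches a relative L2 error, `||pred - truth|| / ||truth||`, between two arrays. `relative_l2_error` is a hypothetical helper. Applying it to the `p` and `pPred` arrays of `mesh` assumes both are single-component cell arrays of the mesh read above, as the plotting cell uses them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def relative_l2_error(truth, pred):\n",
"    # ||pred - truth||_2 / ||truth||_2 over all flattened entries\n",
"    truth = np.asarray(truth, dtype=np.float64).ravel()\n",
"    pred = np.asarray(pred, dtype=np.float64).ravel()\n",
"    return np.linalg.norm(pred - truth) / np.linalg.norm(truth)\n",
"\n",
"\n",
"# Toy check: the error vector is (0, 1) and ||truth|| = 5, so the result is 0.2\n",
"print(relative_l2_error(np.array([3.0, 4.0]), np.array([3.0, 5.0])))\n",
"\n",
"# Usage on the mesh loaded above (assumes both arrays exist as cell data):\n",
"# print(relative_l2_error(mesh.cell_data[\"p\"], mesh.cell_data[\"pPred\"]))"
]
},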
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"This notebook demonstrated the workflow for training and deploying the DoMINO (Decomposable Multiscale Iterative Neural Operator) model on automotive aerodynamics data. DoMINO combines global and local geometric representations of the input geometry with a DeepONet-inspired aggregation network, and it builds its geometric encoding directly from the STL file. Using distributed training with mixed precision, we trained the model to predict surface pressure and wall shear stress on the Ahmed body dataset.\n",
"\n",
"For evaluation, the trained model predicted the surface fields of the test geometries. The predictions were saved to `.vtp` files together with the CFD ground truth for visual comparison. For each case, the x-direction force integrated from the predicted fields was printed alongside the force integrated from the CFD fields. Inference with a trained model costs less than running a new CFD simulation, so this approach is promising for rapid aerodynamic assessment during vehicle design iterations and for design optimization loops."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Shut down the kernel process immediately to release its resources (e.g., GPU memory)\n",
"os._exit(0)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|