-rw-r--r--  COPYRIGHT.txt | 5
-rw-r--r--  modules/denoise/SCsub | 138
-rw-r--r--  modules/denoise/config.py | 12
-rw-r--r--  modules/denoise/denoise_wrapper.cpp | 66
-rw-r--r--  modules/denoise/denoise_wrapper.h | 38
-rw-r--r--  modules/denoise/lightmap_denoiser.cpp | 65
-rw-r--r--  modules/denoise/lightmap_denoiser.h | 56
-rw-r--r--  modules/denoise/register_types.cpp | 49
-rw-r--r--  modules/denoise/register_types.h | 39
-rw-r--r--  modules/denoise/resource_to_cpp.py | 68
-rw-r--r--  thirdparty/README.md | 31
-rw-r--r--  thirdparty/oidn/LICENSE.txt | 202
-rw-r--r--  thirdparty/oidn/common/barrier.h | 52
-rw-r--r--  thirdparty/oidn/common/exception.h | 45
-rw-r--r--  thirdparty/oidn/common/platform.cpp | 114
-rw-r--r--  thirdparty/oidn/common/platform.h | 131
-rw-r--r--  thirdparty/oidn/common/ref.h | 163
-rw-r--r--  thirdparty/oidn/common/tensor.cpp | 83
-rw-r--r--  thirdparty/oidn/common/tensor.h | 66
-rw-r--r--  thirdparty/oidn/common/thread.cpp | 297
-rw-r--r--  thirdparty/oidn/common/thread.h | 202
-rw-r--r--  thirdparty/oidn/common/timer.h | 49
-rw-r--r--  thirdparty/oidn/core/api.cpp | 408
-rw-r--r--  thirdparty/oidn/core/autoencoder.cpp | 535
-rw-r--r--  thirdparty/oidn/core/autoencoder.h | 120
-rw-r--r--  thirdparty/oidn/core/buffer.h | 75
-rw-r--r--  thirdparty/oidn/core/common.h | 136
-rw-r--r--  thirdparty/oidn/core/device.cpp | 238
-rw-r--r--  thirdparty/oidn/core/device.h | 102
-rw-r--r--  thirdparty/oidn/core/filter.cpp | 27
-rw-r--r--  thirdparty/oidn/core/filter.h | 52
-rw-r--r--  thirdparty/oidn/core/image.h | 111
-rw-r--r--  thirdparty/oidn/core/input_reorder.h | 232
-rw-r--r--  thirdparty/oidn/core/math.h | 78
-rw-r--r--  thirdparty/oidn/core/network.cpp | 436
-rw-r--r--  thirdparty/oidn/core/network.h | 112
-rw-r--r--  thirdparty/oidn/core/node.h | 142
-rw-r--r--  thirdparty/oidn/core/output_reorder.h | 126
-rw-r--r--  thirdparty/oidn/core/transfer_function.cpp | 103
-rw-r--r--  thirdparty/oidn/core/transfer_function.h | 201
-rw-r--r--  thirdparty/oidn/core/upsample.h | 92
-rw-r--r--  thirdparty/oidn/core/weights_reorder.h | 99
-rw-r--r--  thirdparty/oidn/include/OpenImageDenoise/oidn.h | 214
-rw-r--r--  thirdparty/oidn/include/OpenImageDenoise/oidn.hpp | 468
-rw-r--r--  thirdparty/oidn/include/OpenImageDenoise/version.h | 23
-rw-r--r--  thirdparty/oidn/mkl-dnn/LICENSE | 214
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn.h | 1771
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn.hpp | 2615
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h | 98
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn_types.h | 1415
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn_version.h | 32
-rw-r--r--  thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in | 32
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp | 104
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp | 240
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp | 550
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/concat.cpp | 86
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp | 211
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution.cpp | 200
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp | 56
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp | 348
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp | 188
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp | 293
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp | 84
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp | 161
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/engine.cpp | 75
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/engine.hpp | 119
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp | 106
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp | 56
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp | 321
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/lrn.cpp | 91
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp | 170
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory.cpp | 238
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory.hpp | 63
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp | 212
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp | 400
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp | 295
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp | 131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp | 365
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp | 277
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp | 77
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/nstl.hpp | 193
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/pooling.cpp | 114
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp | 238
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive.cpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive.hpp | 76
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp | 290
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp | 183
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp | 78
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp | 174
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp | 90
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp | 79
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/query.cpp | 59
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/reorder.cpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp | 85
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/rnn.cpp | 400
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp | 112
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp | 72
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp | 121
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/softmax.cpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp | 161
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/stream.cpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/stream.hpp | 44
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/sum.cpp | 79
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp | 143
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp | 200
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp | 348
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/utils.cpp | 135
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/utils.hpp | 370
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/verbose.cpp | 665
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/verbose.hpp | 62
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp | 112
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp | 60
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp | 40
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp | 140
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp | 43
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp | 51
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp | 41
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp | 74
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp | 45
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp | 324
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp | 70
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp | 84
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp | 151
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp | 42
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp | 277
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp | 40
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp | 83
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp | 544
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp | 334
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp | 262
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp | 48
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp | 41
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp | 45
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp | 48
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp | 39
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp | 372
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp | 72
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp | 2131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp | 2705
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp | 37
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp | 346
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp | 58
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp | 86
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp | 206
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp | 28
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp | 1409
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp | 539
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp | 101
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp | 290
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp | 411
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp | 64
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp | 819
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp | 2209
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp | 564
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp | 501
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp | 1283
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp | 3163
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp | 821
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp | 647
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp | 116
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp | 180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp | 37
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp | 307
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp | 250
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp | 771
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp | 66
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp | 156
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp | 157
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp | 740
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp | 266
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp | 453
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp | 166
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp | 674
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp | 110
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp | 545
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp | 344
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp | 1501
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp | 225
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp | 410
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp | 302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp | 1255
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp | 108
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp | 816
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp | 344
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp | 4539
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp | 423
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp | 1163
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp | 179
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp | 1526
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp | 302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp | 1215
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp | 318
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp | 853
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp | 96
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp | 1103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp | 144
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp | 1020
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp | 386
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp | 2596
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp | 291
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp | 1284
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp | 128
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp | 820
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp | 131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp | 292
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp | 159
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp | 140
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp | 1182
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp | 239
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp | 423
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp | 1034
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp | 237
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp | 773
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp | 481
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp | 677
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp | 104
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp | 134
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp | 96
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp | 497
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp | 93
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp | 136
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp | 1192
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp | 145
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp | 327
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp | 1407
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp | 100
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp | 1302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp | 253
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp | 427
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp | 266
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp | 1142
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp | 193
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp | 949
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp | 305
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp | 1487
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp | 183
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp | 699
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp | 192
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp | 264
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp | 182
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp | 1006
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp | 127
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp | 313
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp | 32
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD | 27
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md | 1
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h | 595
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h | 94
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c | 293
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h | 673
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp | 317
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp | 147
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp | 382
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp | 160
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp | 392
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp | 210
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp | 288
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp | 169
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp | 265
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp | 127
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp | 97
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp | 395
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp | 194
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp | 199
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp | 502
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp | 297
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp | 168
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp | 285
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp | 159
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp | 252
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp | 136
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp | 381
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp | 119
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp | 153
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp | 111
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp | 264
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp | 186
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp | 101
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp | 90
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp | 180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp | 170
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp | 143
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp | 113
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp | 191
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp | 401
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp | 328
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp | 380
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp | 426
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp | 225
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp | 126
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp | 155
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp | 98
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp | 1022
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp | 91
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp | 74
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp | 376
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT | 47
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h | 2658
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h | 303
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h | 2017
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h | 772
-rw-r--r--  thirdparty/oidn/patches/godot-changes-c58c5216.patch | 337
-rw-r--r--  thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch | 45
-rw-r--r--  thirdparty/oidn/weights/LICENSE.txt | 202
-rw-r--r--  thirdparty/oidn/weights/rtlightmap_hdr.tza | bin 5660131 -> 0 bytes
325 files changed, 0 insertions, 114456 deletions
diff --git a/COPYRIGHT.txt b/COPYRIGHT.txt
index 3db389b224..7610465f0b 100644
--- a/COPYRIGHT.txt
+++ b/COPYRIGHT.txt
@@ -425,11 +425,6 @@ Comment: Stripped down version of "nvapi.h" from the NVIDIA NVAPI SDK
Copyright: 2019-2022, NVIDIA Corporation
License: Expat
-Files: ./thirdparty/oidn/
-Comment: Intel Open Image Denoise
-Copyright: 2009-2019, Intel Corporation
-License: Apache-2.0
-
Files: ./thirdparty/openxr/
Comment: OpenXR Loader
Copyright: 2020-2023, The Khronos Group Inc.
diff --git a/modules/denoise/SCsub b/modules/denoise/SCsub
deleted file mode 100644
index 967a511e1e..0000000000
--- a/modules/denoise/SCsub
+++ /dev/null
@@ -1,138 +0,0 @@
-#!/usr/bin/env python
-
-import resource_to_cpp
-
-Import("env")
-Import("env_modules")
-
-env_oidn = env_modules.Clone()
-
-# Thirdparty source files
-
-thirdparty_obj = []
-
-thirdparty_dir = "#thirdparty/oidn/"
-thirdparty_sources = [
- "core/api.cpp",
- "core/device.cpp",
- "core/filter.cpp",
- "core/network.cpp",
- "core/autoencoder.cpp",
- "core/transfer_function.cpp",
- "weights/rtlightmap_hdr.gen.cpp",
- "mkl-dnn/src/common/batch_normalization.cpp",
- "mkl-dnn/src/common/concat.cpp",
- "mkl-dnn/src/common/convolution.cpp",
- "mkl-dnn/src/common/convolution_pd.cpp",
- "mkl-dnn/src/common/deconvolution.cpp",
- "mkl-dnn/src/common/eltwise.cpp",
- "mkl-dnn/src/common/engine.cpp",
- "mkl-dnn/src/common/inner_product.cpp",
- "mkl-dnn/src/common/inner_product_pd.cpp",
- "mkl-dnn/src/common/lrn.cpp",
- "mkl-dnn/src/common/memory.cpp",
- "mkl-dnn/src/common/memory_desc_wrapper.cpp",
- "mkl-dnn/src/common/mkldnn_debug.cpp",
- "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp",
- "mkl-dnn/src/common/pooling.cpp",
- "mkl-dnn/src/common/primitive.cpp",
- "mkl-dnn/src/common/primitive_attr.cpp",
- "mkl-dnn/src/common/primitive_desc.cpp",
- "mkl-dnn/src/common/primitive_exec_types.cpp",
- "mkl-dnn/src/common/primitive_iterator.cpp",
- "mkl-dnn/src/common/query.cpp",
- "mkl-dnn/src/common/reorder.cpp",
- "mkl-dnn/src/common/rnn.cpp",
- "mkl-dnn/src/common/scratchpad.cpp",
- "mkl-dnn/src/common/shuffle.cpp",
- "mkl-dnn/src/common/softmax.cpp",
- "mkl-dnn/src/common/stream.cpp",
- "mkl-dnn/src/common/sum.cpp",
- "mkl-dnn/src/common/utils.cpp",
- "mkl-dnn/src/common/verbose.cpp",
- "mkl-dnn/src/cpu/cpu_barrier.cpp",
- "mkl-dnn/src/cpu/cpu_concat.cpp",
- "mkl-dnn/src/cpu/cpu_engine.cpp",
- "mkl-dnn/src/cpu/cpu_memory.cpp",
- "mkl-dnn/src/cpu/cpu_reducer.cpp",
- "mkl-dnn/src/cpu/cpu_reorder.cpp",
- "mkl-dnn/src/cpu/cpu_sum.cpp",
- "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp",
- "mkl-dnn/src/cpu/jit_avx2_convolution.cpp",
- "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp",
- "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp",
- "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp",
- "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp",
- "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp",
- "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp",
- "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp",
- "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp",
- "mkl-dnn/src/cpu/jit_sse42_convolution.cpp",
- "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp",
- "mkl-dnn/src/cpu/jit_uni_eltwise.cpp",
- "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp",
- "mkl-dnn/src/cpu/jit_uni_pooling.cpp",
- "mkl-dnn/src/cpu/jit_uni_reorder.cpp",
- "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp",
- "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp",
- "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c",
- "common/platform.cpp",
- "common/thread.cpp",
- "common/tensor.cpp",
-]
-thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources]
-
-thirdparty_include_dirs = [
- "",
- "include",
- "mkl-dnn/include",
- "mkl-dnn/src",
- "mkl-dnn/src/common",
- "mkl-dnn/src/cpu/xbyak",
- "mkl-dnn/src/cpu",
-]
-thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs]
-
-
-env_oidn.Prepend(CPPPATH=thirdparty_include_dirs)
-env_oidn.Append(
- CPPDEFINES=[
- "MKLDNN_THR=MKLDNN_THR_SEQ",
- "OIDN_STATIC_LIB",
- "__STDC_CONSTANT_MACROS",
- "__STDC_LIMIT_MACROS",
- "DISABLE_VERBOSE",
- "MKLDNN_ENABLE_CONCURRENT_EXEC",
- ]
-)
-env_oidn.AppendUnique(CPPDEFINES=["NDEBUG"]) # No assert() even in debug builds.
-
-env_thirdparty = env_oidn.Clone()
-env_thirdparty.disable_warnings()
-
-if env["disable_exceptions"]:
- # OIDN hard-requires exceptions, so we re-enable them here.
- if env.msvc and ("_HAS_EXCEPTIONS", 0) in env_thirdparty["CPPDEFINES"]:
- env_thirdparty["CPPDEFINES"].remove(("_HAS_EXCEPTIONS", 0))
- env_thirdparty.AppendUnique(CCFLAGS=["/EHsc"])
- elif not env.msvc and "-fno-exceptions" in env_thirdparty["CCFLAGS"]:
- env_thirdparty["CCFLAGS"].remove("-fno-exceptions")
-
-env_thirdparty.add_source_files(thirdparty_obj, thirdparty_sources)
-env.modules_sources += thirdparty_obj
-
-weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza"
-weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp"
-
-env_thirdparty.Depends(weights_out_path, weights_in_path)
-env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp)
-
-# Godot source files
-
-module_obj = []
-
-env_oidn.add_source_files(module_obj, "*.cpp")
-env.modules_sources += module_obj
-
-# Needed to force rebuilding the module files when the thirdparty library is updated.
-env.Depends(module_obj, thirdparty_obj)
diff --git a/modules/denoise/config.py b/modules/denoise/config.py
deleted file mode 100644
index 27d2ffbf86..0000000000
--- a/modules/denoise/config.py
+++ /dev/null
@@ -1,12 +0,0 @@
-def can_build(env, platform):
- # Thirdparty dependency OpenImage Denoise includes oneDNN library
- # and the version we use only supports x86_64.
- # It's also only relevant for tools build and desktop platforms,
- # as doing lightmap generation and denoising on Android or Web
- # would be a bit far-fetched.
- desktop_platforms = ["linuxbsd", "macos", "windows"]
- return env.editor_build and platform in desktop_platforms and env["arch"] == "x86_64"
-
-
-def configure(env):
- pass
diff --git a/modules/denoise/denoise_wrapper.cpp b/modules/denoise/denoise_wrapper.cpp
deleted file mode 100644
index 87f02cb4c6..0000000000
--- a/modules/denoise/denoise_wrapper.cpp
+++ /dev/null
@@ -1,66 +0,0 @@
-/**************************************************************************/
-/* denoise_wrapper.cpp */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#include "denoise_wrapper.h"
-
-#include <OpenImageDenoise/oidn.h>
-
-#include <stdio.h>
-
-void *oidn_denoiser_init() {
- OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
- oidnCommitDevice(device);
- return device;
-}
-
-bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) {
- OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr;
- OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
- oidnSetSharedFilterImage(filter, "color", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
- oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
- oidnSetFilter1b(filter, "hdr", true);
- //oidnSetFilter1f(filter, "hdrScale", 1.0f);
- oidnCommitFilter(filter);
- oidnExecuteFilter(filter);
-
- const char *msg;
- bool success = true;
- if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) {
- printf("LightmapDenoiser: %s\n", msg);
- success = false;
- }
-
- oidnReleaseFilter(filter);
- return success;
-}
-
-void oidn_denoiser_finish(void *device) {
- oidnReleaseDevice((OIDNDeviceImpl *)device);
-}
diff --git a/modules/denoise/denoise_wrapper.h b/modules/denoise/denoise_wrapper.h
deleted file mode 100644
index d4bf154a5d..0000000000
--- a/modules/denoise/denoise_wrapper.h
+++ /dev/null
@@ -1,38 +0,0 @@
-/**************************************************************************/
-/* denoise_wrapper.h */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#ifndef DENOISE_WRAPPER_H
-#define DENOISE_WRAPPER_H
-
-void *oidn_denoiser_init();
-bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height);
-void oidn_denoiser_finish(void *device);
-
-#endif // DENOISE_WRAPPER_H
diff --git a/modules/denoise/lightmap_denoiser.cpp b/modules/denoise/lightmap_denoiser.cpp
deleted file mode 100644
index 72764036e1..0000000000
--- a/modules/denoise/lightmap_denoiser.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-/**************************************************************************/
-/* lightmap_denoiser.cpp */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#include "lightmap_denoiser.h"
-
-#include "denoise_wrapper.h"
-
-#include "core/io/image.h"
-
-LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() {
- return memnew(LightmapDenoiserOIDN);
-}
-
-void LightmapDenoiserOIDN::make_default_denoiser() {
- create_function = create_oidn_denoiser;
-}
-
-Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) {
- Ref<Image> img = p_image->duplicate();
-
- img->convert(Image::FORMAT_RGBF);
-
- Vector<uint8_t> data = img->get_data();
- if (!oidn_denoise(device, (float *)data.ptrw(), img->get_width(), img->get_height())) {
- return p_image;
- }
-
- img->set_data(img->get_width(), img->get_height(), false, img->get_format(), data);
- return img;
-}
-
-LightmapDenoiserOIDN::LightmapDenoiserOIDN() {
- device = oidn_denoiser_init();
-}
-
-LightmapDenoiserOIDN::~LightmapDenoiserOIDN() {
- oidn_denoiser_finish(device);
-}
diff --git a/modules/denoise/lightmap_denoiser.h b/modules/denoise/lightmap_denoiser.h
deleted file mode 100644
index 8f658ab096..0000000000
--- a/modules/denoise/lightmap_denoiser.h
+++ /dev/null
@@ -1,56 +0,0 @@
-/**************************************************************************/
-/* lightmap_denoiser.h */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#ifndef LIGHTMAP_DENOISER_H
-#define LIGHTMAP_DENOISER_H
-
-#include "core/object/class_db.h"
-#include "scene/3d/lightmapper.h"
-
-struct OIDNDeviceImpl;
-
-class LightmapDenoiserOIDN : public LightmapDenoiser {
- GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser);
-
-protected:
- void *device = nullptr;
-
-public:
- static LightmapDenoiser *create_oidn_denoiser();
-
- Ref<Image> denoise_image(const Ref<Image> &p_image) override;
-
- static void make_default_denoiser();
-
- LightmapDenoiserOIDN();
- ~LightmapDenoiserOIDN();
-};
-
-#endif // LIGHTMAP_DENOISER_H
diff --git a/modules/denoise/register_types.cpp b/modules/denoise/register_types.cpp
deleted file mode 100644
index a4264b07c5..0000000000
--- a/modules/denoise/register_types.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-/**************************************************************************/
-/* register_types.cpp */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#include "register_types.h"
-
-#include "lightmap_denoiser.h"
-
-#include "core/config/engine.h"
-
-void initialize_denoise_module(ModuleInitializationLevel p_level) {
- if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
- return;
- }
-
- LightmapDenoiserOIDN::make_default_denoiser();
-}
-
-void uninitialize_denoise_module(ModuleInitializationLevel p_level) {
- if (p_level != MODULE_INITIALIZATION_LEVEL_SCENE) {
- return;
- }
-}
diff --git a/modules/denoise/register_types.h b/modules/denoise/register_types.h
deleted file mode 100644
index 239877a5c7..0000000000
--- a/modules/denoise/register_types.h
+++ /dev/null
@@ -1,39 +0,0 @@
-/**************************************************************************/
-/* register_types.h */
-/**************************************************************************/
-/* This file is part of: */
-/* GODOT ENGINE */
-/* https://godotengine.org */
-/**************************************************************************/
-/* Copyright (c) 2014-present Godot Engine contributors (see AUTHORS.md). */
-/* Copyright (c) 2007-2014 Juan Linietsky, Ariel Manzur. */
-/* */
-/* Permission is hereby granted, free of charge, to any person obtaining */
-/* a copy of this software and associated documentation files (the */
-/* "Software"), to deal in the Software without restriction, including */
-/* without limitation the rights to use, copy, modify, merge, publish, */
-/* distribute, sublicense, and/or sell copies of the Software, and to */
-/* permit persons to whom the Software is furnished to do so, subject to */
-/* the following conditions: */
-/* */
-/* The above copyright notice and this permission notice shall be */
-/* included in all copies or substantial portions of the Software. */
-/* */
-/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, */
-/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF */
-/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. */
-/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY */
-/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, */
-/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE */
-/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */
-/**************************************************************************/
-
-#ifndef DENOISE_REGISTER_TYPES_H
-#define DENOISE_REGISTER_TYPES_H
-
-#include "modules/register_module_types.h"
-
-void initialize_denoise_module(ModuleInitializationLevel p_level);
-void uninitialize_denoise_module(ModuleInitializationLevel p_level);
-
-#endif // DENOISE_REGISTER_TYPES_H
diff --git a/modules/denoise/resource_to_cpp.py b/modules/denoise/resource_to_cpp.py
deleted file mode 100644
index a89eda9117..0000000000
--- a/modules/denoise/resource_to_cpp.py
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/usr/bin/env python
-
-## ======================================================================== ##
-## Copyright 2009-2019 Intel Corporation ##
-## ##
-## Licensed under the Apache License, Version 2.0 (the "License"); ##
-## you may not use this file except in compliance with the License. ##
-## You may obtain a copy of the License at ##
-## ##
-## http://www.apache.org/licenses/LICENSE-2.0 ##
-## ##
-## Unless required by applicable law or agreed to in writing, software ##
-## distributed under the License is distributed on an "AS IS" BASIS, ##
-## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ##
-## See the License for the specific language governing permissions and ##
-## limitations under the License. ##
-## ======================================================================== ##
-
-import os
-from array import array
-
-
-# Generates a C++ file from the specified binary resource file
-def generate(in_path, out_path):
- namespace = "oidn::weights"
- scopes = namespace.split("::")
-
- file_name = os.path.basename(in_path)
- var_name = os.path.splitext(file_name)[0]
-
- with open(in_path, "rb") as in_file, open(out_path, "w") as out_file:
- # Header
- out_file.write("// Generated from: %s\n" % file_name)
- out_file.write("#include <cstddef>\n\n")
-
- # Open the namespaces
- for s in scopes:
- out_file.write("namespace %s {\n" % s)
- if scopes:
- out_file.write("\n")
-
- # Read the file
- in_data = array("B", in_file.read())
-
- # Write the size
- out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data)))
-
- # Write the data
- out_file.write("unsigned char %s[] = {" % var_name)
- for i in range(len(in_data)):
- c = in_data[i]
- if i > 0:
- out_file.write(",")
- if (i + 1) % 20 == 1:
- out_file.write("\n")
- out_file.write("%d" % c)
- out_file.write("\n};\n")
-
- # Close the namespaces
- if scopes:
- out_file.write("\n")
- for scope in reversed(scopes):
- out_file.write("} // namespace %s\n" % scope)
-
-
-def tza_to_cpp(target, source, env):
- for x in zip(source, target):
- generate(str(x[0]), str(x[1]))
diff --git a/thirdparty/README.md b/thirdparty/README.md
index 1eb95a1a7c..4327eeab8b 100644
--- a/thirdparty/README.md
+++ b/thirdparty/README.md
@@ -642,37 +642,6 @@ Files extracted from the upstream source:
- `nvapi_minimal.h` was created by using `nvapi.h` from upstream and removing unnecessary code.
-## oidn
-
-- Upstream: https://github.com/OpenImageDenoise/oidn
-- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019)
-- License: Apache 2.0
-
-Files extracted from upstream source:
-
-- common/* (except tasking.* and CMakeLists.txt)
-- core/*
-- include/OpenImageDenoise/* (except version.h.in)
-- LICENSE.txt
-- mkl-dnn/include/*
-- mkl-dnn/src/* (except CMakeLists.txt)
-- weights/rtlightmap_hdr.tza
-- scripts/resource_to_cpp.py
-
-Modified files:
-Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`.
-Patch files are provided in `oidn/patches/`.
-
-- core/autoencoder.cpp
-- core/autoencoder.h
-- core/common.h
-- core/device.cpp
-- core/device.h
-- core/transfer_function.cpp
-
-- scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py)
-
-
## openxr
- Upstream: https://github.com/KhronosGroup/OpenXR-SDK
diff --git a/thirdparty/oidn/LICENSE.txt b/thirdparty/oidn/LICENSE.txt
deleted file mode 100644
index d645695673..0000000000
--- a/thirdparty/oidn/LICENSE.txt
+++ /dev/null
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/thirdparty/oidn/common/barrier.h b/thirdparty/oidn/common/barrier.h
deleted file mode 100644
index b20f670053..0000000000
--- a/thirdparty/oidn/common/barrier.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "platform.h"
-#include <mutex>
-#include <condition_variable>
-
-namespace oidn {
-
- class Barrier
- {
- private:
- std::mutex m;
- std::condition_variable cv;
- volatile int count;
-
- public:
- Barrier(int count) : count(count) {}
-
- void wait()
- {
- std::unique_lock<std::mutex> lk(m);
- count--;
-
- if (count == 0)
- {
- lk.unlock();
- cv.notify_all();
- }
- else
- {
- cv.wait(lk, [&]{ return count == 0; });
- }
- }
- };
-
-} // namespace oidn
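
A minimal usage sketch of the Barrier above: `count` threads each call wait(), and none returns until all have arrived. The worker count and thread bodies below are illustrative assumptions, and note the barrier as written is single-use (the counter is never reset).

// Sketch: synchronize four workers once on an oidn::Barrier (single-use).
#include "barrier.h"
#include <thread>
#include <vector>
#include <cstdio>

int main()
{
  const int numThreads = 4;              // illustrative worker count
  oidn::Barrier barrier(numThreads);
  std::vector<std::thread> workers;

  for (int i = 0; i < numThreads; ++i)
  {
    workers.emplace_back([&barrier, i]
    {
      std::printf("worker %d: before barrier\n", i);
      barrier.wait();                    // blocks until all numThreads threads have arrived
      std::printf("worker %d: after barrier\n", i);
    });
  }

  for (auto& t : workers)
    t.join();
  return 0;
}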
diff --git a/thirdparty/oidn/common/exception.h b/thirdparty/oidn/common/exception.h
deleted file mode 100644
index 18069c6a7d..0000000000
--- a/thirdparty/oidn/common/exception.h
+++ /dev/null
@@ -1,45 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include <exception>
-#include "platform.h"
-
-namespace oidn {
-
- class Exception : public std::exception
- {
- private:
- Error error;
- const char* message;
-
- public:
- Exception(Error error, const char* message)
- : error(error), message(message) {}
-
- Error code() const noexcept
- {
- return error;
- }
-
- const char* what() const noexcept override
- {
- return message;
- }
- };
-
-} // namespace oidn
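
The Exception type above pairs an Error code with a static message string; it is what the rest of the library throws and what the C API boundary in core/api.cpp converts back into error codes. A small illustrative sketch:

// Sketch: throwing and catching an oidn::Exception.
#include "exception.h"
#include <iostream>

void example()
{
  try
  {
    throw oidn::Exception(oidn::Error::InvalidOperation,
                          "changes to the filter are not committed");
  }
  catch (const oidn::Exception& e)
  {
    std::cerr << "error " << int(e.code()) << ": " << e.what() << std::endl;
  }
}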
diff --git a/thirdparty/oidn/common/platform.cpp b/thirdparty/oidn/common/platform.cpp
deleted file mode 100644
index 59a14ff47c..0000000000
--- a/thirdparty/oidn/common/platform.cpp
+++ /dev/null
@@ -1,114 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "platform.h"
-
-namespace oidn {
-
- // ----------------------------------------------------------------------------
- // Common functions
- // ----------------------------------------------------------------------------
-
- void* alignedMalloc(size_t size, size_t alignment)
- {
- if (size == 0)
- return nullptr;
-
- assert((alignment & (alignment-1)) == 0);
- void* ptr = _mm_malloc(size, alignment);
-
- if (ptr == nullptr)
- throw std::bad_alloc();
-
- return ptr;
- }
-
- void alignedFree(void* ptr)
- {
- if (ptr)
- _mm_free(ptr);
- }
-
- // ----------------------------------------------------------------------------
- // System information
- // ----------------------------------------------------------------------------
-
- std::string getPlatformName()
- {
- std::string name;
-
- #if defined(__linux__)
- name = "Linux";
- #elif defined(__FreeBSD__)
- name = "FreeBSD";
- #elif defined(__CYGWIN__)
- name = "Cygwin";
- #elif defined(_WIN32)
- name = "Windows";
- #elif defined(__APPLE__)
- name = "macOS";
- #elif defined(__unix__)
- name = "Unix";
- #else
- return "Unknown";
- #endif
-
- #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__)
- name += " (64-bit)";
- #else
- name += " (32-bit)";
- #endif
-
- return name;
- }
-
- std::string getCompilerName()
- {
- #if defined(__INTEL_COMPILER)
- int mayor = __INTEL_COMPILER / 100 % 100;
- int minor = __INTEL_COMPILER % 100;
- std::string version = "Intel Compiler ";
- version += toString(mayor);
- version += "." + toString(minor);
- #if defined(__INTEL_COMPILER_UPDATE)
- version += "." + toString(__INTEL_COMPILER_UPDATE);
- #endif
- return version;
- #elif defined(__clang__)
- return "Clang " __clang_version__;
- #elif defined(__GNUC__)
- return "GCC " __VERSION__;
- #elif defined(_MSC_VER)
- std::string version = toString(_MSC_FULL_VER);
- version.insert(4, ".");
- version.insert(9, ".");
- version.insert(2, ".");
- return "Visual C++ Compiler " + version;
- #else
- return "Unknown";
- #endif
- }
-
- std::string getBuildName()
- {
- #if defined(NDEBUG)
- return "Release";
- #else
- return "Debug";
- #endif
- }
-
-} // namespace oidn
diff --git a/thirdparty/oidn/common/platform.h b/thirdparty/oidn/common/platform.h
deleted file mode 100644
index 9373b617b5..0000000000
--- a/thirdparty/oidn/common/platform.h
+++ /dev/null
@@ -1,131 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#if defined(_WIN32)
- #define WIN32_LEAN_AND_MEAN
- #define NOMINMAX
- #include <windows.h>
-#elif defined(__APPLE__)
- #include <sys/sysctl.h>
-#endif
-
-#include <xmmintrin.h>
-#include <cstdint>
-#include <climits>
-#include <limits>
-#include <atomic>
-#include <algorithm>
-#include <memory>
-#include <cmath>
-#include <string>
-#include <sstream>
-#include <iostream>
-#include <cassert>
-#include "include/OpenImageDenoise/oidn.hpp"
-
-namespace oidn {
-
- // ----------------------------------------------------------------------------
- // Macros
- // ----------------------------------------------------------------------------
-
- #if defined(_WIN32)
- // Windows
- #if !defined(__noinline)
- #define __noinline __declspec(noinline)
- #endif
- #else
- // Unix
- #if !defined(__forceinline)
- #define __forceinline inline __attribute__((always_inline))
- #endif
- #if !defined(__noinline)
- #define __noinline __attribute__((noinline))
- #endif
- #endif
-
- #ifndef UNUSED
- #define UNUSED(x) ((void)x)
- #endif
- #ifndef MAYBE_UNUSED
- #define MAYBE_UNUSED(x) UNUSED(x)
- #endif
-
- // ----------------------------------------------------------------------------
- // Error handling and debugging
- // ----------------------------------------------------------------------------
-
- struct Verbose
- {
- int verbose;
-
- Verbose(int v = 0) : verbose(v) {}
- __forceinline bool isVerbose(int v = 1) const { return v <= verbose; }
- };
-
- #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; }
- #define OIDN_FATAL(message) throw std::runtime_error(message);
-
- // ----------------------------------------------------------------------------
- // Common functions
- // ----------------------------------------------------------------------------
-
- using std::min;
- using std::max;
-
- template<typename T>
- __forceinline T clamp(const T& value, const T& minValue, const T& maxValue)
- {
- return min(max(value, minValue), maxValue);
- }
-
- void* alignedMalloc(size_t size, size_t alignment);
- void alignedFree(void* ptr);
-
- template<typename T>
- inline std::string toString(const T& a)
- {
- std::stringstream sm;
- sm << a;
- return sm.str();
- }
-
-#if defined(__APPLE__)
- template<typename T>
- bool getSysctl(const char* name, T& value)
- {
- int64_t result = 0;
- size_t size = sizeof(result);
-
- if (sysctlbyname(name, &result, &size, nullptr, 0) != 0)
- return false;
-
- value = T(result);
- return true;
- }
-#endif
-
- // ----------------------------------------------------------------------------
- // System information
- // ----------------------------------------------------------------------------
-
- std::string getPlatformName();
- std::string getCompilerName();
- std::string getBuildName();
-
-} // namespace oidn
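
The helpers declared above in use; the sizes and values below are illustrative:

// Sketch: aligned allocation plus the clamp/toString convenience helpers.
#include "platform.h"

void example()
{
  // 64-byte aligned allocation; throws std::bad_alloc on failure.
  float* data = (float*)oidn::alignedMalloc(1024 * sizeof(float), 64);
  oidn::alignedFree(data);

  int clamped = oidn::clamp(42, 0, 15);        // -> 15
  std::string text = oidn::toString(3.5);      // stream-based conversion -> "3.5"
}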
diff --git a/thirdparty/oidn/common/ref.h b/thirdparty/oidn/common/ref.h
deleted file mode 100644
index de44603af2..0000000000
--- a/thirdparty/oidn/common/ref.h
+++ /dev/null
@@ -1,163 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "platform.h"
-
-namespace oidn {
-
- class RefCount
- {
- private:
- std::atomic<size_t> count;
-
- public:
- __forceinline RefCount(int count = 0) noexcept : count(count) {}
-
- __forceinline size_t incRef() noexcept
- {
- return count.fetch_add(1) + 1;
- }
-
- __forceinline size_t decRef()
- {
- const size_t newCount = decRefKeep();
- if (newCount == 0)
- destroy();
- return newCount;
- }
-
- __forceinline size_t decRefKeep() noexcept
- {
- return count.fetch_add(-1) - 1;
- }
-
- __forceinline void destroy()
- {
- delete this;
- }
-
- protected:
- // Disable copying
- RefCount(const RefCount&) = delete;
- RefCount& operator =(const RefCount&) = delete;
-
- virtual ~RefCount() noexcept = default;
- };
-
- template<typename T>
- class Ref
- {
- private:
- T* ptr;
-
- public:
- __forceinline Ref() noexcept : ptr(nullptr) {}
- __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {}
- __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); }
- __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; }
- __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
-
- template<typename Y>
- __forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); }
-
- template<typename Y>
- __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
-
- __forceinline ~Ref() { if (ptr) ptr->decRef(); }
-
- __forceinline Ref& operator =(const Ref& other)
- {
- if (other.ptr)
- other.ptr->incRef();
- if (ptr)
- ptr->decRef();
- ptr = other.ptr;
- return *this;
- }
-
- __forceinline Ref& operator =(Ref&& other)
- {
- if (ptr)
- ptr->decRef();
- ptr = other.ptr;
- other.ptr = nullptr;
- return *this;
- }
-
- __forceinline Ref& operator =(T* other)
- {
- if (other)
- other->incRef();
- if (ptr)
- ptr->decRef();
- ptr = other;
- return *this;
- }
-
- __forceinline Ref& operator =(std::nullptr_t)
- {
- if (ptr)
- ptr->decRef();
- ptr = nullptr;
- return *this;
- }
-
- __forceinline operator bool() const noexcept { return ptr != nullptr; }
-
- __forceinline T& operator *() const noexcept { return *ptr; }
- __forceinline T* operator ->() const noexcept { return ptr; }
-
- __forceinline T* get() const noexcept { return ptr; }
-
- __forceinline T* detach() noexcept
- {
- T* res = ptr;
- ptr = nullptr;
- return res;
- }
- };
-
- template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr < b.ptr; }
-
- template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr == nullptr; }
- template<typename T> __forceinline bool operator ==(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr == b.ptr; }
- template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr == b.ptr; }
-
- template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t) noexcept { return a.ptr != nullptr; }
- template<typename T> __forceinline bool operator !=(std::nullptr_t, const Ref<T>& b) noexcept { return nullptr != b.ptr; }
- template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr != b.ptr; }
-
- template<typename T, typename... Args>
- __forceinline Ref<T> makeRef(Args&&... args)
- {
- return Ref<T>(new T(std::forward<Args>(args)...));
- }
-
- template<typename T, typename Y>
- __forceinline Ref<Y> staticRefCast(const Ref<T>& a)
- {
- return Ref<Y>(static_cast<Y*>(a.get()));
- }
-
- template<typename T, typename Y>
- __forceinline Ref<Y> dynamicRefCast(const Ref<T>& a)
- {
- return Ref<Y>(dynamic_cast<Y*>(a.get()));
- }
-
-} // namespace oidn
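
Ref<T> together with RefCount forms the intrusive smart pointer used throughout the library; a type opts in by deriving from RefCount. The Widget type below is an illustrative assumption:

// Sketch: an intrusively reference-counted type managed through oidn::Ref.
#include "ref.h"

struct Widget : public oidn::RefCount
{
  int value = 0;
};

void example()
{
  oidn::Ref<Widget> a = oidn::makeRef<Widget>();  // refcount: 1
  a->value = 7;

  oidn::Ref<Widget> b = a;                        // refcount: 2
  Widget* raw = b.detach();                       // b gives up ownership without decRef()
  raw->decRef();                                  // balance the detached reference -> 1
}                                                 // a's destructor drops the last reference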
diff --git a/thirdparty/oidn/common/tensor.cpp b/thirdparty/oidn/common/tensor.cpp
deleted file mode 100644
index 0249f2e141..0000000000
--- a/thirdparty/oidn/common/tensor.cpp
+++ /dev/null
@@ -1,83 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "exception.h"
-#include "tensor.h"
-
-namespace oidn {
-
- std::map<std::string, Tensor> parseTensors(void* buffer)
- {
- char* input = (char*)buffer;
-
- // Parse the magic value
- const int magic = *(unsigned short*)input;
- if (magic != 0x41D7)
- throw Exception(Error::InvalidOperation, "invalid tensor archive");
- input += sizeof(unsigned short);
-
- // Parse the version
- const int majorVersion = *(unsigned char*)input++;
- const int minorVersion = *(unsigned char*)input++;
- UNUSED(minorVersion);
- if (majorVersion > 1)
- throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
-
- // Parse the number of tensors
- const int numTensors = *(int*)input;
- input += sizeof(int);
-
- // Parse the tensors
- std::map<std::string, Tensor> tensorMap;
- for (int i = 0; i < numTensors; ++i)
- {
- Tensor tensor;
-
- // Parse the name
- const int nameLen = *(unsigned char*)input++;
- std::string name(input, nameLen);
- input += nameLen;
-
- // Parse the number of dimensions
- const int ndims = *(unsigned char*)input++;
-
- // Parse the shape of the tensor
- tensor.dims.resize(ndims);
- for (int i = 0; i < ndims; ++i)
- tensor.dims[i] = ((int*)input)[i];
- input += ndims * sizeof(int);
-
- // Parse the format of the tensor
- tensor.format = std::string(input, input + ndims);
- input += ndims;
-
- // Parse the data type of the tensor
- const char type = *(unsigned char*)input++;
- if (type != 'f') // only float32 is supported
- throw Exception(Error::InvalidOperation, "unsupported tensor data type");
-
- // Skip the data
- tensor.data = (float*)input;
- input += tensor.size() * sizeof(float);
-
- // Add the tensor to the map
- tensorMap.emplace(name, std::move(tensor));
- }
-
- return tensorMap;
- }
-
-} // namespace oidn
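
parseTensors() above defines the binary layout of a tensor archive: a 0x41D7 magic, a version pair, a tensor count, and per tensor a name, dimensions, per-dimension format characters, an 'f' (float32) type tag and the raw data. A sketch that assembles a matching one-tensor archive in memory; the tensor name, dims and format characters are illustrative assumptions:

// Sketch: build a minimal archive that oidn::parseTensors() accepts.
#include "tensor.h"
#include <vector>

std::vector<char> makeArchive()
{
  std::vector<char> buf;
  auto put = [&](const void* p, size_t n)
    { buf.insert(buf.end(), (const char*)p, (const char*)p + n); };

  const unsigned short magic = 0x41D7;      put(&magic, sizeof(magic));
  const unsigned char  version[2] = {1, 0}; put(version, 2);          // major, minor
  const int            numTensors = 1;      put(&numTensors, sizeof(numTensors));

  const unsigned char nameLen = 1;          put(&nameLen, 1); put("w", 1);
  const unsigned char ndims = 2;            put(&ndims, 1);
  const int dims[2] = {2, 2};               put(dims, sizeof(dims));
  put("xy", 2);                             // one format character per dimension
  const unsigned char type = 'f';           put(&type, 1);            // float32 only
  const float data[4] = {1, 2, 3, 4};       put(data, sizeof(data));
  return buf;
}

// Usage: the buffer must outlive the parsed map, since Tensor::data points into it.
// auto archive = makeArchive();
// auto tensors = oidn::parseTensors(archive.data());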
diff --git a/thirdparty/oidn/common/tensor.h b/thirdparty/oidn/common/tensor.h
deleted file mode 100644
index 48e7d1123d..0000000000
--- a/thirdparty/oidn/common/tensor.h
+++ /dev/null
@@ -1,66 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "platform.h"
-#include <vector>
-#include <map>
-
-namespace oidn {
-
- template<typename T>
- using shared_vector = std::shared_ptr<std::vector<T>>;
-
- // Generic tensor
- struct Tensor
- {
- float* data;
- std::vector<int64_t> dims;
- std::string format;
- shared_vector<char> buffer; // optional, only for reference counting
-
- __forceinline Tensor() : data(nullptr) {}
-
- __forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
- : dims(dims),
- format(format)
- {
- buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
- data = (float*)buffer->data();
- }
-
- __forceinline operator bool() const { return data != nullptr; }
-
- __forceinline int ndims() const { return (int)dims.size(); }
-
- // Returns the number of values
- __forceinline size_t size() const
- {
- size_t size = 1;
- for (int i = 0; i < ndims(); ++i)
- size *= dims[i];
- return size;
- }
-
- __forceinline float& operator [](size_t i) { return data[i]; }
- __forceinline const float& operator [](size_t i) const { return data[i]; }
- };
-
- // Parses tensors from a buffer
- std::map<std::string, Tensor> parseTensors(void* buffer);
-
-} // namespace oidn
diff --git a/thirdparty/oidn/common/thread.cpp b/thirdparty/oidn/common/thread.cpp
deleted file mode 100644
index 48c489c57b..0000000000
--- a/thirdparty/oidn/common/thread.cpp
+++ /dev/null
@@ -1,297 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#if defined(_MSC_VER)
- #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned
-#endif
-
-#if defined(__APPLE__)
- #include <mach/thread_act.h>
- #include <mach/mach_init.h>
-#endif
-
-#include "thread.h"
-#include <fstream>
-
-namespace oidn {
-
-#if defined(_WIN32)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - Windows
- // --------------------------------------------------------------------------
-
- ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
- : Verbose(verbose)
- {
- HMODULE hLib = GetModuleHandle(TEXT("kernel32"));
- pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx");
- pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity");
-
- if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity)
- {
- // Get logical processor information
- PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr;
- DWORD bufferSize = 0;
-
- // First call the function with an empty buffer to get the required buffer size
- BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
- if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
- {
- OIDN_WARNING("GetLogicalProcessorInformationEx failed");
- return;
- }
-
- // Allocate the buffer
- buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize);
- if (!buffer)
- {
- OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed");
- return;
- }
-
- // Call again the function but now with the properly sized buffer
- result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
- if (!result)
- {
- OIDN_WARNING("GetLogicalProcessorInformationEx failed");
- free(buffer);
- return;
- }
-
- // Iterate over the logical processor information structures
- // There should be one structure for each physical core
- char* ptr = (char*)buffer;
- while (ptr < (char*)buffer + bufferSize)
- {
- PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr;
- if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0)
- {
- // Iterate over the groups
- int numThreads = 0;
- for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group)
- {
- GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group];
- while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore))
- {
- // Extract the next set bit/thread from the mask
- GROUP_AFFINITY threadAffinity = coreAffinity;
- threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask;
-
- // Push the affinity for this thread
- affinities.push_back(threadAffinity);
- oldAffinities.push_back(threadAffinity);
- numThreads++;
-
- // Remove this bit/thread from the mask
- coreAffinity.Mask ^= threadAffinity.Mask;
- }
- }
- }
-
- // Next structure
- ptr += item->Size;
- }
-
- // Free the buffer
- free(buffer);
- }
- }
-
- void ThreadAffinity::set(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- // Save the current affinity and set the new one
- const HANDLE thread = GetCurrentThread();
- if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex]))
- OIDN_WARNING("SetThreadGroupAffinity failed");
- }
-
- void ThreadAffinity::restore(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- // Restore the original affinity
- const HANDLE thread = GetCurrentThread();
- if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr))
- OIDN_WARNING("SetThreadGroupAffinity failed");
- }
-
-#elif defined(__linux__)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - Linux
- // --------------------------------------------------------------------------
-
- ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
- : Verbose(verbose)
- {
- std::vector<int> threadIds;
-
- // Parse the thread/CPU topology
- for (int cpuId = 0; ; cpuId++)
- {
- std::fstream fs;
- std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list");
- fs.open(cpu.c_str(), std::fstream::in);
- if (fs.fail()) break;
-
- int i;
- int j = 0;
- while ((j < numThreadsPerCore) && (fs >> i))
- {
- if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; }))
- threadIds.push_back(i);
-
- if (fs.peek() == ',')
- fs.ignore();
- j++;
- }
-
- fs.close();
- }
-
- #if 0
- for (size_t i = 0; i < thread_ids.size(); ++i)
- std::cout << "thread " << i << " -> " << thread_ids[i] << std::endl;
- #endif
-
- // Create the affinity structures
- affinities.resize(threadIds.size());
- oldAffinities.resize(threadIds.size());
-
- for (size_t i = 0; i < threadIds.size(); ++i)
- {
- cpu_set_t affinity;
- CPU_ZERO(&affinity);
- CPU_SET(threadIds[i], &affinity);
-
- affinities[i] = affinity;
- oldAffinities[i] = affinity;
- }
- }
-
- void ThreadAffinity::set(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- const pthread_t thread = pthread_self();
-
- // Save the current affinity
- if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
- {
- OIDN_WARNING("pthread_getaffinity_np failed");
- oldAffinities[threadIndex] = affinities[threadIndex];
- return;
- }
-
- // Set the new affinity
- if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0)
- OIDN_WARNING("pthread_setaffinity_np failed");
- }
-
- void ThreadAffinity::restore(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- const pthread_t thread = pthread_self();
-
- // Restore the original affinity
- if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
- OIDN_WARNING("pthread_setaffinity_np failed");
- }
-
-#elif defined(__APPLE__)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - macOS
- // --------------------------------------------------------------------------
-
- ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
- : Verbose(verbose)
- {
- // Query the thread/CPU topology
- int numPhysicalCpus;
- int numLogicalCpus;
-
- if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus))
- {
- OIDN_WARNING("sysctlbyname failed");
- return;
- }
-
- if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1))
- return; // this shouldn't happen
- const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus;
-
- // Create the affinity structures
- // macOS doesn't support binding a thread to a specific core, but we can at least group threads which
- // should be on the same core together
- for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1!
- {
- thread_affinity_policy affinity;
- affinity.affinity_tag = core;
-
- for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread)
- {
- affinities.push_back(affinity);
- oldAffinities.push_back(affinity);
- }
- }
- }
-
- void ThreadAffinity::set(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- const auto thread = mach_thread_self();
-
- // Save the current affinity
- mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT;
- boolean_t getDefault = FALSE;
- if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS)
- {
- OIDN_WARNING("thread_policy_get failed");
- oldAffinities[threadIndex] = affinities[threadIndex];
- return;
- }
-
- // Set the new affinity
- if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
- OIDN_WARNING("thread_policy_set failed");
- }
-
- void ThreadAffinity::restore(int threadIndex)
- {
- if (threadIndex >= (int)affinities.size())
- return;
-
- const auto thread = mach_thread_self();
-
- // Restore the original affinity
- if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
- OIDN_WARNING("thread_policy_set failed");
- }
-
-#endif
-
-} // namespace oidn
diff --git a/thirdparty/oidn/common/thread.h b/thirdparty/oidn/common/thread.h
deleted file mode 100644
index 2c731367da..0000000000
--- a/thirdparty/oidn/common/thread.h
+++ /dev/null
@@ -1,202 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "platform.h"
-
-#if !defined(_WIN32)
- #include <pthread.h>
- #include <sched.h>
- #if defined(__APPLE__)
- #include <mach/thread_policy.h>
- #endif
-#endif
-
-#include <vector>
-#include <mutex>
-
-namespace oidn {
-
- // --------------------------------------------------------------------------
- // ThreadLocal
- // --------------------------------------------------------------------------
-
- // Wrapper which makes any variable thread-local
- template<typename T>
- class ThreadLocal : public Verbose
- {
- private:
- #if defined(_WIN32)
- DWORD key;
- #else
- pthread_key_t key;
- #endif
-
- std::vector<T*> instances;
- std::mutex mutex;
-
- public:
- ThreadLocal(int verbose = 0)
- : Verbose(verbose)
- {
- #if defined(_WIN32)
- key = TlsAlloc();
- if (key == TLS_OUT_OF_INDEXES)
- OIDN_FATAL("TlsAlloc failed");
- #else
- if (pthread_key_create(&key, nullptr) != 0)
- OIDN_FATAL("pthread_key_create failed");
- #endif
- }
-
- ~ThreadLocal()
- {
- std::lock_guard<std::mutex> lock(mutex);
- for (T* ptr : instances)
- delete ptr;
-
- #if defined(_WIN32)
- if (!TlsFree(key))
- OIDN_WARNING("TlsFree failed");
- #else
- if (pthread_key_delete(key) != 0)
- OIDN_WARNING("pthread_key_delete failed");
- #endif
- }
-
- T& get()
- {
- #if defined(_WIN32)
- T* ptr = (T*)TlsGetValue(key);
- #else
- T* ptr = (T*)pthread_getspecific(key);
- #endif
-
- if (ptr)
- return *ptr;
-
- ptr = new T;
- std::lock_guard<std::mutex> lock(mutex);
- instances.push_back(ptr);
-
- #if defined(_WIN32)
- if (!TlsSetValue(key, ptr))
- OIDN_FATAL("TlsSetValue failed");
- #else
- if (pthread_setspecific(key, ptr) != 0)
- OIDN_FATAL("pthread_setspecific failed");
- #endif
-
- return *ptr;
- }
- };
-
-#if defined(_WIN32)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - Windows
- // --------------------------------------------------------------------------
-
- class ThreadAffinity : public Verbose
- {
- private:
- typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP,
- PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
- PDWORD);
-
- typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE,
- CONST GROUP_AFFINITY*,
- PGROUP_AFFINITY);
-
- GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr;
- SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr;
-
- std::vector<GROUP_AFFINITY> affinities; // thread affinities
- std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities
-
- public:
- ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
-
- int getNumThreads() const
- {
- return (int)affinities.size();
- }
-
- // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
- void set(int threadIndex);
-
- // Restores the affinity of the thread
- void restore(int threadIndex);
- };
-
-#elif defined(__linux__)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - Linux
- // --------------------------------------------------------------------------
-
- class ThreadAffinity : public Verbose
- {
- private:
- std::vector<cpu_set_t> affinities; // thread affinities
- std::vector<cpu_set_t> oldAffinities; // original thread affinities
-
- public:
- ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
-
- int getNumThreads() const
- {
- return (int)affinities.size();
- }
-
- // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
- void set(int threadIndex);
-
- // Restores the affinity of the thread
- void restore(int threadIndex);
- };
-
-#elif defined(__APPLE__)
-
- // --------------------------------------------------------------------------
- // ThreadAffinity - macOS
- // --------------------------------------------------------------------------
-
- class ThreadAffinity : public Verbose
- {
- private:
- std::vector<thread_affinity_policy> affinities; // thread affinities
- std::vector<thread_affinity_policy> oldAffinities; // original thread affinities
-
- public:
- ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
-
- int getNumThreads() const
- {
- return (int)affinities.size();
- }
-
- // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
- void set(int threadIndex);
-
- // Restores the affinity of the thread
- void restore(int threadIndex);
- };
-
-#endif
-
-} // namespace oidn
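
ThreadLocal<T> lazily creates one T per calling thread via a TLS key and deletes every instance when the wrapper itself is destroyed; ThreadAffinity pins worker threads to cores on each platform. A sketch of the ThreadLocal side; the Scratch type is an illustrative assumption:

// Sketch: one Scratch instance per thread, created on first use.
#include "thread.h"
#include <thread>

struct Scratch { int counter = 0; };

oidn::ThreadLocal<Scratch> perThread;

void worker()
{
  Scratch& s = perThread.get();  // first call on this thread allocates its Scratch
  ++s.counter;                   // never shared with other threads
}

// Usage: std::thread(worker).join(); std::thread(worker).join();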
diff --git a/thirdparty/oidn/common/timer.h b/thirdparty/oidn/common/timer.h
deleted file mode 100644
index 62aaaa1c33..0000000000
--- a/thirdparty/oidn/common/timer.h
+++ /dev/null
@@ -1,49 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "platform.h"
-#include <chrono>
-
-namespace oidn {
-
- class Timer
- {
- private:
- using clock = std::chrono::high_resolution_clock;
-
- std::chrono::time_point<clock> start;
-
- public:
- Timer()
- {
- reset();
- }
-
- void reset()
- {
- start = clock::now();
- }
-
- double query() const
- {
- auto end = clock::now();
- return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
- }
- };
-
-} // namespace oidn
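
Timer wraps std::chrono::high_resolution_clock and reports elapsed wall-clock time in seconds:

// Sketch: timing a block of work with oidn::Timer.
#include "timer.h"

void example()
{
  oidn::Timer timer;
  // ... work to be measured ...
  double seconds = timer.query();  // elapsed seconds since construction (or reset())
  timer.reset();                   // restart the measurement
  (void)seconds;
}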
diff --git a/thirdparty/oidn/core/api.cpp b/thirdparty/oidn/core/api.cpp
deleted file mode 100644
index 7353fe4e25..0000000000
--- a/thirdparty/oidn/core/api.cpp
+++ /dev/null
@@ -1,408 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#ifdef _WIN32
-# define OIDN_API extern "C" __declspec(dllexport)
-#else
-# define OIDN_API extern "C" __attribute__ ((visibility ("default")))
-#endif
-
-// Locks the device that owns the specified object
-// Use *only* inside OIDN_TRY/CATCH!
-#define OIDN_LOCK(obj) \
- std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex());
-
-// Try/catch for converting exceptions to errors
-#define OIDN_TRY \
- try {
-
-#define OIDN_CATCH(obj) \
- } catch (Exception& e) { \
- Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what()); \
- } catch (std::bad_alloc&) { \
- Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
- } catch (mkldnn::error& e) { \
- if (e.status == mkldnn_out_of_memory) \
- Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory"); \
- else \
- Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message); \
- } catch (std::exception& e) { \
- Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what()); \
- } catch (...) { \
- Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
- }
-
-#include "device.h"
-#include "filter.h"
-#include <mutex>
-
-namespace oidn {
-
- namespace
- {
- __forceinline void checkHandle(void* handle)
- {
- if (handle == nullptr)
- throw Exception(Error::InvalidArgument, "invalid handle");
- }
-
- template<typename T>
- __forceinline void retainObject(T* obj)
- {
- if (obj)
- {
- obj->incRef();
- }
- else
- {
- OIDN_TRY
- checkHandle(obj);
- OIDN_CATCH(obj)
- }
- }
-
- template<typename T>
- __forceinline void releaseObject(T* obj)
- {
- if (obj == nullptr || obj->decRefKeep() == 0)
- {
- OIDN_TRY
- checkHandle(obj);
- OIDN_LOCK(obj);
- obj->destroy();
- OIDN_CATCH(obj)
- }
- }
-
- template<>
- __forceinline void releaseObject(Device* obj)
- {
- if (obj == nullptr || obj->decRefKeep() == 0)
- {
- OIDN_TRY
- checkHandle(obj);
- // Do NOT lock the device because it owns the mutex
- obj->destroy();
- OIDN_CATCH(obj)
- }
- }
- }
-
- OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
- {
- Ref<Device> device = nullptr;
- OIDN_TRY
- if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
- device = makeRef<Device>();
- else
- throw Exception(Error::InvalidArgument, "invalid device type");
- OIDN_CATCH(device)
- return (OIDNDevice)device.detach();
- }
-
- OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
- {
- Device* device = (Device*)hDevice;
- retainObject(device);
- }
-
- OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
- {
- Device* device = (Device*)hDevice;
- releaseObject(device);
- }
-
- OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- device->set1i(name, value);
- OIDN_CATCH(device)
- }
-
- OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- device->set1i(name, value);
- OIDN_CATCH(device)
- }
-
- OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- return device->get1i(name);
- OIDN_CATCH(device)
- return false;
- }
-
- OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- return device->get1i(name);
- OIDN_CATCH(device)
- return 0;
- }
-
- OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- device->setErrorFunction((ErrorFunction)func, userPtr);
- OIDN_CATCH(device)
- }
-
- OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- return (OIDNError)Device::getError(device, outMessage);
- OIDN_CATCH(device)
- if (outMessage) *outMessage = "";
- return OIDN_ERROR_UNKNOWN;
- }
-
- OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- device->commit();
- OIDN_CATCH(device)
- }
-
- OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- Ref<Buffer> buffer = device->newBuffer(byteSize);
- return (OIDNBuffer)buffer.detach();
- OIDN_CATCH(device)
- return nullptr;
- }
-
- OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
- return (OIDNBuffer)buffer.detach();
- OIDN_CATCH(device)
- return nullptr;
- }
-
- OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
- {
- Buffer* buffer = (Buffer*)hBuffer;
- retainObject(buffer);
- }
-
- OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
- {
- Buffer* buffer = (Buffer*)hBuffer;
- releaseObject(buffer);
- }
-
- OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
- {
- Buffer* buffer = (Buffer*)hBuffer;
- OIDN_TRY
- checkHandle(hBuffer);
- OIDN_LOCK(buffer);
- return buffer->map(byteOffset, byteSize);
- OIDN_CATCH(buffer)
- return nullptr;
- }
-
- OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
- {
- Buffer* buffer = (Buffer*)hBuffer;
- OIDN_TRY
- checkHandle(hBuffer);
- OIDN_LOCK(buffer);
- return buffer->unmap(mappedPtr);
- OIDN_CATCH(buffer)
- }
-
- OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
- {
- Device* device = (Device*)hDevice;
- OIDN_TRY
- checkHandle(hDevice);
- OIDN_LOCK(device);
- Ref<Filter> filter = device->newFilter(type);
- return (OIDNFilter)filter.detach();
- OIDN_CATCH(device)
- return nullptr;
- }
-
- OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
- {
- Filter* filter = (Filter*)hFilter;
- retainObject(filter);
- }
-
- OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
- {
- Filter* filter = (Filter*)hFilter;
- releaseObject(filter);
- }
-
- OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
- OIDNBuffer hBuffer, OIDNFormat format,
- size_t width, size_t height,
- size_t byteOffset,
- size_t bytePixelStride, size_t byteRowStride)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- checkHandle(hBuffer);
- OIDN_LOCK(filter);
- Ref<Buffer> buffer = (Buffer*)hBuffer;
- if (buffer->getDevice() != filter->getDevice())
- throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
- Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
- filter->setImage(name, data);
- OIDN_CATCH(filter)
- }
-
- OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
- void* ptr, OIDNFormat format,
- size_t width, size_t height,
- size_t byteOffset,
- size_t bytePixelStride, size_t byteRowStride)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
- filter->setImage(name, data);
- OIDN_CATCH(filter)
- }
-
- OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->set1i(name, int(value));
- OIDN_CATCH(filter)
- }
-
- OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- return filter->get1i(name);
- OIDN_CATCH(filter)
- return false;
- }
-
- OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->set1i(name, value);
- OIDN_CATCH(filter)
- }
-
- OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- return filter->get1i(name);
- OIDN_CATCH(filter)
- return 0;
- }
-
- OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->set1f(name, value);
- OIDN_CATCH(filter)
- }
-
- OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- return filter->get1f(name);
- OIDN_CATCH(filter)
- return 0;
- }
-
- OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->setProgressMonitorFunction(func, userPtr);
- OIDN_CATCH(filter)
- }
-
- OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->commit();
- OIDN_CATCH(filter)
- }
-
- OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
- {
- Filter* filter = (Filter*)hFilter;
- OIDN_TRY
- checkHandle(hFilter);
- OIDN_LOCK(filter);
- filter->execute();
- OIDN_CATCH(filter)
- }
-
-} // namespace oidn
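
Taken together, these entry points are the C API that callers (including Godot's denoise module) drive: create and commit a device, create a filter, bind images, commit and execute. A minimal end-to-end sketch; the "RT" filter name, the OIDN_FORMAT_FLOAT3 and OIDN_ERROR_NONE constants, and zero strides meaning tightly packed come from the public OpenImageDenoise headers and documentation rather than this file:

// Sketch: denoise a float RGB image through the C API.
#include <OpenImageDenoise/oidn.h>
#include <cstddef>
#include <cstdio>

void denoise(float* colorIn, float* colorOut, size_t width, size_t height)
{
  OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
  oidnCommitDevice(device);

  OIDNFilter filter = oidnNewFilter(device, "RT");  // generic ray-tracing denoiser
  oidnSetSharedFilterImage(filter, "color",  colorIn,  OIDN_FORMAT_FLOAT3,
                           width, height, 0, 0, 0);  // 0 strides: tightly packed
  oidnSetSharedFilterImage(filter, "output", colorOut, OIDN_FORMAT_FLOAT3,
                           width, height, 0, 0, 0);
  oidnSetFilter1b(filter, "hdr", true);
  oidnCommitFilter(filter);
  oidnExecuteFilter(filter);

  const char* message;
  if (oidnGetDeviceError(device, &message) != OIDN_ERROR_NONE)
    std::printf("OIDN error: %s\n", message);

  oidnReleaseFilter(filter);
  oidnReleaseDevice(device);
}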
diff --git a/thirdparty/oidn/core/autoencoder.cpp b/thirdparty/oidn/core/autoencoder.cpp
deleted file mode 100644
index d8da684cb8..0000000000
--- a/thirdparty/oidn/core/autoencoder.cpp
+++ /dev/null
@@ -1,535 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "autoencoder.h"
-
-namespace oidn {
-
- // --------------------------------------------------------------------------
- // AutoencoderFilter
- // --------------------------------------------------------------------------
-
- AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
- : Filter(device)
- {
- }
-
- void AutoencoderFilter::setImage(const std::string& name, const Image& data)
- {
- if (name == "color")
- color = data;
- else if (name == "albedo")
- albedo = data;
- else if (name == "normal")
- normal = data;
- else if (name == "output")
- output = data;
-
- dirty = true;
- }
-
- void AutoencoderFilter::set1i(const std::string& name, int value)
- {
- if (name == "hdr")
- hdr = value;
- else if (name == "srgb")
- srgb = value;
- else if (name == "maxMemoryMB")
- maxMemoryMB = value;
-
- dirty = true;
- }
-
- int AutoencoderFilter::get1i(const std::string& name)
- {
- if (name == "hdr")
- return hdr;
- else if (name == "srgb")
- return srgb;
- else if (name == "maxMemoryMB")
- return maxMemoryMB;
- else if (name == "alignment")
- return alignment;
- else if (name == "overlap")
- return overlap;
- else
- throw Exception(Error::InvalidArgument, "invalid parameter");
- }
-
- void AutoencoderFilter::set1f(const std::string& name, float value)
- {
- if (name == "hdrScale")
- hdrScale = value;
-
- dirty = true;
- }
-
- float AutoencoderFilter::get1f(const std::string& name)
- {
- if (name == "hdrScale")
- return hdrScale;
- else
- throw Exception(Error::InvalidArgument, "invalid parameter");
- }
-
- void AutoencoderFilter::commit()
- {
- if (!dirty)
- return;
-
- // -- GODOT start --
- //device->executeTask([&]()
- //{
- // GODOT end --
-
- if (mayiuse(avx512_common))
- net = buildNet<16>();
- else
- net = buildNet<8>();
-
- // GODOT start --
- //});
- // GODOT end --
-
- dirty = false;
- }
-
- void AutoencoderFilter::execute()
- {
- if (dirty)
- throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
-
- if (!net)
- return;
- // -- GODOT start --
- //device->executeTask([&]()
- //{
- // -- GODOT end --
- Progress progress;
- progress.func = progressFunc;
- progress.userPtr = progressUserPtr;
- progress.taskCount = tileCountH * tileCountW;
-
- // Iterate over the tiles
- int tileIndex = 0;
-
- for (int i = 0; i < tileCountH; ++i)
- {
- const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
- const int overlapBeginH = i > 0 ? overlap : 0; // overlap on the top
- const int overlapEndH = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
- const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
- const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
- const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
-
- for (int j = 0; j < tileCountW; ++j)
- {
- const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
- const int overlapBeginW = j > 0 ? overlap : 0; // overlap on the left
- const int overlapEndW = j < tileCountW-1 ? overlap : 0; // overlap on the right
- const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
- const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
- const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
-
- // Set the input tile
- inputReorder->setTile(h, w,
- alignOffsetH, alignOffsetW,
- tileH1, tileW1);
-
- // Set the output tile
- outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
- h + overlapBeginH, w + overlapBeginW,
- tileH2, tileW2);
-
- //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
-
- // Denoise the tile
- net->execute(progress, tileIndex);
-
- // Next tile
- tileIndex++;
- }
- }
- // -- GODOT start --
- //});
- // -- GODOT end --
- }
-
- void AutoencoderFilter::computeTileSize()
- {
- const int minTileSize = 3*overlap;
- const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
- const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
-
- tileCountH = 1;
- tileCountW = 1;
- tileH = roundUp(H, alignment);
- tileW = roundUp(W, alignment);
-
- // Divide the image into tiles until the tile size gets below the threshold
- while (int64_t(tileH) * tileW > maxTilePixels)
- {
- if (tileH > minTileSize && tileH > tileW)
- {
- tileCountH++;
- tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
- }
- else if (tileW > minTileSize)
- {
- tileCountW++;
- tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
- }
- else
- break;
- }
-
- // Compute the final number of tiles
- tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
- tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
-
- if (device->isVerbose(2))
- {
- std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
- std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
- }
- }
-
- template<int K>
- std::shared_ptr<Executable> AutoencoderFilter::buildNet()
- {
- H = color.height;
- W = color.width;
-
- // Configure the network
- int inputC;
- void* weightPtr;
-
- if (srgb && hdr)
- throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
-
- if (color && !albedo && !normal && weightData.hdr)
- {
- inputC = 3;
- weightPtr = hdr ? weightData.hdr : weightData.ldr;
- }
- else if (color && albedo && !normal && weightData.hdr_alb)
- {
- inputC = 6;
- weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
- }
- else if (color && albedo && normal && weightData.hdr_alb_nrm)
- {
- inputC = 9;
- weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
- }
- else
- {
- throw Exception(Error::InvalidOperation, "unsupported combination of input features");
- }
-
- if (!output)
- throw Exception(Error::InvalidOperation, "output image not specified");
-
- if ((color.format != Format::Float3)
- || (albedo && albedo.format != Format::Float3)
- || (normal && normal.format != Format::Float3)
- || (output.format != Format::Float3))
- throw Exception(Error::InvalidOperation, "unsupported image format");
-
- if ((albedo && (albedo.width != W || albedo.height != H))
- || (normal && (normal.width != W || normal.height != H))
- || (output.width != W || output.height != H))
- throw Exception(Error::InvalidOperation, "image size mismatch");
-
- // Compute the tile size
- computeTileSize();
-
- // If the image size is zero, there is nothing else to do
- if (H <= 0 || W <= 0)
- return nullptr;
-
- // Parse the weights
- const auto weightMap = parseTensors(weightPtr);
-
- // Create the network
- std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
-
- // Compute the tensor sizes
- const auto inputDims = memory::dims({1, inputC, tileH, tileW});
- const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment); //-> concat0
-
- const auto conv1Dims = net->getConvDims("conv1", inputReorderDims); //-> temp0
- const auto conv1bDims = net->getConvDims("conv1b", conv1Dims); //-> temp1
- const auto pool1Dims = net->getPoolDims(conv1bDims); //-> concat1
- const auto conv2Dims = net->getConvDims("conv2", pool1Dims); //-> temp0
- const auto pool2Dims = net->getPoolDims(conv2Dims); //-> concat2
- const auto conv3Dims = net->getConvDims("conv3", pool2Dims); //-> temp0
- const auto pool3Dims = net->getPoolDims(conv3Dims); //-> concat3
- const auto conv4Dims = net->getConvDims("conv4", pool3Dims); //-> temp0
- const auto pool4Dims = net->getPoolDims(conv4Dims); //-> concat4
- const auto conv5Dims = net->getConvDims("conv5", pool4Dims); //-> temp0
- const auto pool5Dims = net->getPoolDims(conv5Dims); //-> temp1
- const auto upsample4Dims = net->getUpsampleDims(pool5Dims); //-> concat4
- const auto concat4Dims = net->getConcatDims(upsample4Dims, pool4Dims);
- const auto conv6Dims = net->getConvDims("conv6", concat4Dims); //-> temp0
- const auto conv6bDims = net->getConvDims("conv6b", conv6Dims); //-> temp1
- const auto upsample3Dims = net->getUpsampleDims(conv6bDims); //-> concat3
- const auto concat3Dims = net->getConcatDims(upsample3Dims, pool3Dims);
- const auto conv7Dims = net->getConvDims("conv7", concat3Dims); //-> temp0
- const auto conv7bDims = net->getConvDims("conv7b", conv7Dims); //-> temp1
- const auto upsample2Dims = net->getUpsampleDims(conv7bDims); //-> concat2
- const auto concat2Dims = net->getConcatDims(upsample2Dims, pool2Dims);
- const auto conv8Dims = net->getConvDims("conv8", concat2Dims); //-> temp0
- const auto conv8bDims = net->getConvDims("conv8b", conv8Dims); //-> temp1
- const auto upsample1Dims = net->getUpsampleDims(conv8bDims); //-> concat1
- const auto concat1Dims = net->getConcatDims(upsample1Dims, pool1Dims);
- const auto conv9Dims = net->getConvDims("conv9", concat1Dims); //-> temp0
- const auto conv9bDims = net->getConvDims("conv9b", conv9Dims); //-> temp1
- const auto upsample0Dims = net->getUpsampleDims(conv9bDims); //-> concat0
- const auto concat0Dims = net->getConcatDims(upsample0Dims, inputReorderDims);
- const auto conv10Dims = net->getConvDims("conv10", concat0Dims); //-> temp0
- const auto conv10bDims = net->getConvDims("conv10b", conv10Dims); //-> temp1
- const auto conv11Dims = net->getConvDims("conv11", conv10bDims); //-> temp0
-
- const auto outputDims = memory::dims({1, 3, tileH, tileW});
-
- // Allocate two temporary ping-pong buffers to decrease memory usage
- const auto temp0Dims = getMaxTensorDims({
- conv1Dims,
- conv2Dims,
- conv3Dims,
- conv4Dims,
- conv5Dims,
- conv6Dims,
- conv7Dims,
- conv8Dims,
- conv9Dims,
- conv10Dims,
- conv11Dims
- });
-
- const auto temp1Dims = getMaxTensorDims({
- conv1bDims,
- pool5Dims,
- conv6bDims,
- conv7bDims,
- conv8bDims,
- conv9bDims,
- conv10bDims,
- });
-
- auto temp0 = net->allocTensor(temp0Dims);
- auto temp1 = net->allocTensor(temp1Dims);
-
- // Allocate enough memory to hold the concat outputs. Then use the first
- // half to hold the previous conv output and the second half to hold the
- // pool/orig image output. This works because everything is C dimension
- // outermost, padded to K floats, and all the concats are on the C dimension.
- auto concat0Dst = net->allocTensor(concat0Dims);
- auto concat1Dst = net->allocTensor(concat1Dims);
- auto concat2Dst = net->allocTensor(concat2Dims);
- auto concat3Dst = net->allocTensor(concat3Dims);
- auto concat4Dst = net->allocTensor(concat4Dims);
-
- // Transfer function
- std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
-
- // Autoexposure
- if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
- {
- if (isnan(hdrScale))
- net->addAutoexposure(color, tf);
- else
- tf->setExposure(hdrScale);
- }
-
- // Input reorder
- auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
- inputReorder = net->addInputReorder(color, albedo, normal,
- transferFunc,
- alignment, inputReorderDst);
-
- // conv1
- auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
-
- // conv1b
- auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
-
- // pool1
- // Adjust pointer for pool1 to eliminate concat1
- auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
- auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
-
- // conv2
- auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
-
- // pool2
- // Adjust pointer for pool2 to eliminate concat2
- auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
- auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
-
- // conv3
- auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
-
- // pool3
- // Adjust pointer for pool3 to eliminate concat3
- auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
- auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
-
- // conv4
- auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
-
- // pool4
- // Adjust pointer for pool4 to eliminate concat4
- auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
- auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
-
- // conv5
- auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
-
- // pool5
- auto pool5 = net->addPool(conv5->getDst(), temp1);
-
- // upsample4
- auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
- auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
-
- // conv6
- auto conv6 = net->addConv("conv6", concat4Dst, temp0);
-
- // conv6b
- auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
-
- // upsample3
- auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
- auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
-
- // conv7
- auto conv7 = net->addConv("conv7", concat3Dst, temp0);
-
- // conv7b
- auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
-
- // upsample2
- auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
- auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
-
- // conv8
- auto conv8 = net->addConv("conv8", concat2Dst, temp0);
-
- // conv8b
- auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
-
- // upsample1
- auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
- auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
-
- // conv9
- auto conv9 = net->addConv("conv9", concat1Dst, temp0);
-
- // conv9b
- auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
-
- // upsample0
- auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
- auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
-
- // conv10
- auto conv10 = net->addConv("conv10", concat0Dst, temp0);
-
- // conv10b
- auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
-
- // conv11
- auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
-
- // Output reorder
- outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
-
- net->finalize();
- return net;
- }
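The buildNet() body above avoids explicit concat nodes entirely: each concatN buffer is allocated once, the upsample result is written into its first half and the matching pool (or input reorder) result into its second half via castTensor() offsets, so the two tensors are already adjacent along the channel dimension. A toy illustration of why that works, using plain float arrays and made-up sizes rather than the mkldnn memory objects the real code uses:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // With an nChwKc-style layout the channel blocks are the outermost
    // dimension after the batch, so a (C1+C2)-channel tensor is just the
    // C1-channel tensor followed in memory by the C2-channel tensor.
    // Writing the upsample output into the first part of one allocation and
    // the pool output into the second part therefore *is* the concatenation.
    int main()
    {
      const int K = 8;                 // channel block size
      const int C1 = 16, C2 = 8;       // channels of the two inputs (multiples of K)
      const int H = 4, W = 4;

      const size_t half0 = size_t(C1 / K) * H * W * K;  // floats in the first part
      const size_t half1 = size_t(C2 / K) * H * W * K;  // floats in the second part
      std::vector<float> concat(half0 + half1);

      float* upsampleDst = concat.data();          // "castTensor(upsampleDims, concatDst)"
      float* poolDst     = concat.data() + half0;  // "castTensor(poolDims, concatDst, upsampleDims)"

      // Pretend the upsample and pool nodes wrote their outputs.
      for (size_t i = 0; i < half0; ++i) upsampleDst[i] = 1.f;
      for (size_t i = 0; i < half1; ++i) poolDst[i]     = 2.f;

      // The shared buffer is already the concatenated (C1+C2)-channel tensor.
      assert(concat[0] == 1.f && concat[half0] == 2.f);
      return 0;
    }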
-
- std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
- {
- if (hdr)
- return std::make_shared<PQXTransferFunction>();
- else if (srgb)
- return std::make_shared<LinearTransferFunction>();
- else
- return std::make_shared<GammaTransferFunction>();
- }
-
-// -- GODOT start --
-// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-#if 0
-// -- GODOT end --
-
- // --------------------------------------------------------------------------
- // RTFilter
- // --------------------------------------------------------------------------
-
- namespace weights
- {
- // LDR
- extern unsigned char rt_ldr[]; // color
- extern unsigned char rt_ldr_alb[]; // color, albedo
- extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal
-
- // HDR
- extern unsigned char rt_hdr[]; // color
- extern unsigned char rt_hdr_alb[]; // color, albedo
- extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
- }
-
- RTFilter::RTFilter(const Ref<Device>& device)
- : AutoencoderFilter(device)
- {
- weightData.ldr = weights::rt_ldr;
- weightData.ldr_alb = weights::rt_ldr_alb;
- weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
- weightData.hdr = weights::rt_hdr;
- weightData.hdr_alb = weights::rt_hdr_alb;
- weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
- }
-// -- GODOT start --
-#endif
-// -- GODOT end --
-
- // --------------------------------------------------------------------------
- // RTLightmapFilter
- // --------------------------------------------------------------------------
-
- namespace weights
- {
- // HDR
- extern unsigned char rtlightmap_hdr[]; // color
- }
-
- RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
- : AutoencoderFilter(device)
- {
- weightData.hdr = weights::rtlightmap_hdr;
-
- hdr = true;
- }
-
- std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
- {
- return std::make_shared<LogTransferFunction>();
- }
-
-} // namespace oidn
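RTLightmapFilter above swaps the default transfer function for a logarithmic one, which compresses the very wide dynamic range of HDR lightmap texels before they reach the network and expands it again afterwards. The actual LogTransferFunction is defined in transfer_function.h/.cpp (removed elsewhere in this commit), so the forward/inverse pair below is only an illustrative stand-in for the idea, not the shipped formula, and the name LogTransferSketch is invented for this sketch:

    #include <cassert>
    #include <cmath>

    struct LogTransferSketch
    {
      float exposure = 1.f;  // would come from hdrScale or autoexposure
      // HDR radiance -> compressed network input
      float forward(float y) const { return std::log2(y * exposure + 1.f); }
      // compressed network output -> HDR radiance
      float inverse(float x) const { return (std::exp2(x) - 1.f) / exposure; }
    };

    int main()
    {
      LogTransferSketch tf;
      tf.exposure = 4.f;
      const float y = 123.f;
      // The pair must round-trip so denoised values land back in linear HDR.
      assert(std::fabs(tf.inverse(tf.forward(y)) - y) < 0.01f);
      return 0;
    }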
diff --git a/thirdparty/oidn/core/autoencoder.h b/thirdparty/oidn/core/autoencoder.h
deleted file mode 100644
index 98b610844e..0000000000
--- a/thirdparty/oidn/core/autoencoder.h
+++ /dev/null
@@ -1,120 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "filter.h"
-#include "network.h"
-#include "transfer_function.h"
-
-namespace oidn {
-
- // --------------------------------------------------------------------------
- // AutoencoderFilter - Direct-predicting autoencoder
- // --------------------------------------------------------------------------
-
- class AutoencoderFilter : public Filter
- {
- protected:
- static constexpr int alignment = 32; // required spatial alignment in pixels (padding may be necessary)
- static constexpr int receptiveField = 222; // receptive field in pixels
- static constexpr int overlap = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
-
- static constexpr int estimatedBytesBase = 16*1024*1024; // estimated base memory usage
- static constexpr int estimatedBytesPerPixel8 = 889; // estimated memory usage per pixel for K=8
- static constexpr int estimatedBytesPerPixel16 = 2185; // estimated memory usage per pixel for K=16
-
- Image color;
- Image albedo;
- Image normal;
- Image output;
- bool hdr = false;
- float hdrScale = std::numeric_limits<float>::quiet_NaN();
- bool srgb = false;
- int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
-
- int H = 0; // image height
- int W = 0; // image width
- int tileH = 0; // tile height
- int tileW = 0; // tile width
- int tileCountH = 1; // number of tiles in H dimension
- int tileCountW = 1; // number of tiles in W dimension
-
- std::shared_ptr<Executable> net;
- std::shared_ptr<Node> inputReorder;
- std::shared_ptr<Node> outputReorder;
-
- struct
- {
- void* ldr = nullptr;
- void* ldr_alb = nullptr;
- void* ldr_alb_nrm = nullptr;
- void* hdr = nullptr;
- void* hdr_alb = nullptr;
- void* hdr_alb_nrm = nullptr;
- } weightData;
-
- explicit AutoencoderFilter(const Ref<Device>& device);
- virtual std::shared_ptr<TransferFunction> makeTransferFunc();
-
- public:
- void setImage(const std::string& name, const Image& data) override;
- void set1i(const std::string& name, int value) override;
- int get1i(const std::string& name) override;
- void set1f(const std::string& name, float value) override;
- float get1f(const std::string& name) override;
-
- void commit() override;
- void execute() override;
-
- private:
- void computeTileSize();
-
- template<int K>
- std::shared_ptr<Executable> buildNet();
-
- bool isCommitted() const { return bool(net); }
- };
-
- // --------------------------------------------------------------------------
- // RTFilter - Generic ray tracing denoiser
- // --------------------------------------------------------------------------
-
-// -- GODOT start --
-// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-#if 0
-// -- GODOT end --
- class RTFilter : public AutoencoderFilter
- {
- public:
- explicit RTFilter(const Ref<Device>& device);
- };
-// -- GODOT start --
-#endif
-// -- GODOT end --
-
- // --------------------------------------------------------------------------
- // RTLightmapFilter - Ray traced lightmap denoiser
- // --------------------------------------------------------------------------
-
- class RTLightmapFilter : public AutoencoderFilter
- {
- public:
- explicit RTLightmapFilter(const Ref<Device>& device);
- std::shared_ptr<TransferFunction> makeTransferFunc() override;
- };
-
-} // namespace oidn
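The constants above encode the tiling policy: tiles must overlap by half of the 222-pixel receptive field, rounded up to the 32-pixel alignment, and the tile area is capped so that the per-pixel memory estimate plus the base cost stays within maxMemoryMB. computeTileSize() itself is not shown in this hunk, so the arithmetic below is only an assumed outline of that budget calculation, with roundUpInt standing in for the roundUp helper:

    #include <cstdio>

    constexpr int roundUpInt(int a, int b) { return (a + b - 1) / b * b; }

    int main()
    {
      const int alignment = 32;
      const int receptiveField = 222;
      const int overlap = roundUpInt(receptiveField / 2, alignment);   // = 128 pixels

      const long long maxMemory = 6000LL * 1024 * 1024;                // maxMemoryMB
      const long long base = 16LL * 1024 * 1024;                       // estimatedBytesBase
      const long long perPixel = 2185;                                 // estimatedBytesPerPixel16
      const long long maxTilePixels = (maxMemory - base) / perPixel;

      std::printf("overlap = %d px, max tile area ~ %lld px\n", overlap, maxTilePixels);
      return 0;
    }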
diff --git a/thirdparty/oidn/core/buffer.h b/thirdparty/oidn/core/buffer.h
deleted file mode 100644
index b95109152e..0000000000
--- a/thirdparty/oidn/core/buffer.h
+++ /dev/null
@@ -1,75 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-#include "device.h"
-
-namespace oidn {
-
- class Device;
-
- // Buffer which may or may not own its data
- class Buffer : public RefCount
- {
- private:
- char* ptr;
- size_t byteSize;
- bool shared;
- Ref<Device> device;
-
- public:
- __forceinline Buffer(const Ref<Device>& device, size_t size)
- : ptr((char*)alignedMalloc(size, 64)),
- byteSize(size),
- shared(false),
- device(device) {}
-
- __forceinline Buffer(const Ref<Device>& device, void* data, size_t size)
- : ptr((char*)data),
- byteSize(size),
- shared(true),
- device(device)
- {
- if (data == nullptr)
- throw Exception(Error::InvalidArgument, "buffer pointer null");
- }
-
- __forceinline ~Buffer()
- {
- if (!shared)
- alignedFree(ptr);
- }
-
- __forceinline char* data() { return ptr; }
- __forceinline const char* data() const { return ptr; }
- __forceinline size_t size() const { return byteSize; }
-
- void* map(size_t offset, size_t size)
- {
- if (offset + size > byteSize)
- throw Exception(Error::InvalidArgument, "buffer region out of range");
-
- return ptr + offset;
- }
-
- void unmap(void* mappedPtr) {}
-
- Device* getDevice() { return device.get(); }
- };
-
-} // namespace oidn
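Buffer above is deliberately thin: it either owns a 64-byte aligned allocation or wraps caller memory, and map() is nothing more than a bounds-checked pointer offset (no copying, no real mapping). A simplified, self-contained stand-in that mirrors those two ownership modes; BufferSketch is an invented name and plain malloc replaces the library's alignedMalloc:

    #include <cstdlib>
    #include <new>
    #include <stdexcept>

    class BufferSketch
    {
      char*  ptr;
      size_t byteSize;
      bool   shared;   // true when wrapping caller memory

    public:
      // Owning mode: allocate and later free the storage.
      explicit BufferSketch(size_t size)
        : ptr(static_cast<char*>(std::malloc(size))), byteSize(size), shared(false)
      { if (!ptr) throw std::bad_alloc(); }

      // Shared mode: just remember the caller's pointer.
      BufferSketch(void* data, size_t size)
        : ptr(static_cast<char*>(data)), byteSize(size), shared(true)
      { if (!data) throw std::invalid_argument("buffer pointer null"); }

      BufferSketch(const BufferSketch&) = delete;
      ~BufferSketch() { if (!shared) std::free(ptr); }

      // "Mapping" is only a range check plus a pointer offset.
      void* map(size_t offset, size_t size)
      {
        if (offset + size > byteSize) throw std::out_of_range("buffer region out of range");
        return ptr + offset;
      }
    };

    int main()
    {
      BufferSketch owned(256);
      float external[16] = {};
      BufferSketch wrapped(external, sizeof(external));
      owned.map(128, 64);
      wrapped.map(0, sizeof(external));
      return 0;
    }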
diff --git a/thirdparty/oidn/core/common.h b/thirdparty/oidn/core/common.h
deleted file mode 100644
index a35dd908b4..0000000000
--- a/thirdparty/oidn/core/common.h
+++ /dev/null
@@ -1,136 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common/platform.h"
-
-#include "mkl-dnn/include/mkldnn.hpp"
-#include "mkl-dnn/include/mkldnn_debug.h"
-#include "mkl-dnn/src/common/mkldnn_thread.hpp"
-#include "mkl-dnn/src/common/type_helpers.hpp"
-#include "mkl-dnn/src/cpu/jit_generator.hpp"
-
-#include "common/ref.h"
-#include "common/exception.h"
-#include "common/thread.h"
-// -- GODOT start --
-//#include "common/tasking.h"
-// -- GODOT end --
-#include "math.h"
-
-namespace oidn {
-
- using namespace mkldnn;
- using namespace mkldnn::impl::cpu;
- using mkldnn::impl::parallel_nd;
- using mkldnn::impl::memory_desc_matches_tag;
-
-
- inline size_t getFormatBytes(Format format)
- {
- switch (format)
- {
- case Format::Undefined: return 1;
- case Format::Float: return sizeof(float);
- case Format::Float2: return sizeof(float)*2;
- case Format::Float3: return sizeof(float)*3;
- case Format::Float4: return sizeof(float)*4;
- }
- assert(0);
- return 0;
- }
-
-
- inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
- {
- const mkldnn_memory_desc_t& desc = mem->get_desc().data;
- return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
- }
-
- inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
- {
- const mkldnn_memory_desc_t& desc = mem->get_desc().data;
- return memory::data_type(desc.data_type);
- }
-
- // Returns the number of values in a tensor
- inline size_t getTensorSize(const memory::dims& dims)
- {
- size_t res = 1;
- for (int i = 0; i < (int)dims.size(); ++i)
- res *= dims[i];
- return res;
- }
-
- inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
- {
- memory::dims result;
- size_t maxSize = 0;
-
- for (const auto& d : dims)
- {
- const size_t size = getTensorSize(d);
- if (size > maxSize)
- {
- result = d;
- maxSize = size;
- }
- }
-
- return result;
- }
-
- inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
- {
- return getTensorSize(getTensorDims(mem));
- }
-
-
- template<int K>
- inline int getPadded(int dim)
- {
- return (dim + (K-1)) & ~(K-1);
- }
-
- template<int K>
- inline memory::dims getPadded_nchw(const memory::dims& dims)
- {
- assert(dims.size() == 4);
- memory::dims padDims = dims;
- padDims[1] = getPadded<K>(dims[1]); // pad C
- return padDims;
- }
-
-
- template<int K>
- struct BlockedFormat;
-
- template<>
- struct BlockedFormat<8>
- {
- static constexpr memory::format_tag nChwKc = memory::format_tag::nChw8c;
- static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
- };
-
- template<>
- struct BlockedFormat<16>
- {
- static constexpr memory::format_tag nChwKc = memory::format_tag::nChw16c;
- static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
- };
-
-} // namespace oidn
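getPadded<K>() above rounds a channel count up to the next multiple of K with the classic power-of-two trick: add K-1 and mask away the low bits, which only works because K is 8 or 16. A quick standalone check of that identity against the plain division form (helper names here are illustrative):

    #include <cassert>

    template<int K>
    int getPaddedBitTrick(int dim) { return (dim + (K - 1)) & ~(K - 1); }

    int getPaddedByDivision(int dim, int k) { return (dim + k - 1) / k * k; }

    int main()
    {
      for (int c = 0; c <= 100; ++c)
      {
        assert(getPaddedBitTrick<8>(c)  == getPaddedByDivision(c, 8));   // e.g. 3 -> 8, 9 -> 16
        assert(getPaddedBitTrick<16>(c) == getPaddedByDivision(c, 16));  // e.g. 3 -> 16, 17 -> 32
      }
      return 0;
    }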
diff --git a/thirdparty/oidn/core/device.cpp b/thirdparty/oidn/core/device.cpp
deleted file mode 100644
index 3cd658b9c8..0000000000
--- a/thirdparty/oidn/core/device.cpp
+++ /dev/null
@@ -1,238 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "device.h"
-#include "autoencoder.h"
-
-namespace oidn {
-
- thread_local Device::ErrorState Device::globalError;
-
- Device::Device()
- {
- if (!mayiuse(sse41))
- throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum");
- }
-
- Device::~Device()
- {
- // -- GODOT start --
- //observer.reset();
- // -- GODOT end --
- }
-
- void Device::setError(Device* device, Error code, const std::string& message)
- {
- // Update the stored error only if the previous error was queried
- if (device)
- {
- ErrorState& curError = device->error.get();
-
- if (curError.code == Error::None)
- {
- curError.code = code;
- curError.message = message;
- }
-
- // Print the error message in verbose mode
- if (device->isVerbose())
- std::cerr << "Error: " << message << std::endl;
-
- // Call the error callback function
- ErrorFunction errorFunc;
- void* errorUserPtr;
-
- {
- std::lock_guard<std::mutex> lock(device->mutex);
- errorFunc = device->errorFunc;
- errorUserPtr = device->errorUserPtr;
- }
-
- if (errorFunc)
- errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str());
- }
- else
- {
- if (globalError.code == Error::None)
- {
- globalError.code = code;
- globalError.message = message;
- }
- }
- }
-
- Error Device::getError(Device* device, const char** outMessage)
- {
- // Return and clear the stored error code, but keep the error message so pointers to it will
- // remain valid until the next getError call
- if (device)
- {
- ErrorState& curError = device->error.get();
- const Error code = curError.code;
- if (outMessage)
- *outMessage = (code == Error::None) ? nullptr : curError.message.c_str();
- curError.code = Error::None;
- return code;
- }
- else
- {
- const Error code = globalError.code;
- if (outMessage)
- *outMessage = (code == Error::None) ? nullptr : globalError.message.c_str();
- globalError.code = Error::None;
- return code;
- }
- }
-
- void Device::setErrorFunction(ErrorFunction func, void* userPtr)
- {
- errorFunc = func;
- errorUserPtr = userPtr;
- }
-
- int Device::get1i(const std::string& name)
- {
- if (name == "numThreads")
- return numThreads;
- else if (name == "setAffinity")
- return setAffinity;
- else if (name == "verbose")
- return verbose;
- else if (name == "version")
- return OIDN_VERSION;
- else if (name == "versionMajor")
- return OIDN_VERSION_MAJOR;
- else if (name == "versionMinor")
- return OIDN_VERSION_MINOR;
- else if (name == "versionPatch")
- return OIDN_VERSION_PATCH;
- else
- throw Exception(Error::InvalidArgument, "invalid parameter");
- }
-
- void Device::set1i(const std::string& name, int value)
- {
- if (name == "numThreads")
- numThreads = value;
- else if (name == "setAffinity")
- setAffinity = value;
- else if (name == "verbose")
- {
- verbose = value;
- error.verbose = value;
- }
-
- dirty = true;
- }
-
- void Device::commit()
- {
- if (isCommitted())
- throw Exception(Error::InvalidOperation, "device can be committed only once");
-
- // -- GODOT start --
- #if 0
- // -- GODOT end --
- // Get the optimal thread affinities
- if (setAffinity)
- {
- affinity = std::make_shared<ThreadAffinity>(1, verbose); // one thread per core
- if (affinity->getNumThreads() == 0)
- affinity.reset();
- }
-
- // Create the task arena
- const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency();
- numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads;
- arena = std::make_shared<tbb::task_arena>(numThreads);
-
- // Automatically set the thread affinities
- if (affinity)
- observer = std::make_shared<PinningObserver>(affinity, *arena);
- // -- GODOT start --
- #endif
- numThreads = 1;
- // -- GODOT end --
- dirty = false;
-
- if (isVerbose())
- print();
- }
-
- void Device::checkCommitted()
- {
- if (dirty)
- throw Exception(Error::InvalidOperation, "changes to the device are not committed");
- }
-
- Ref<Buffer> Device::newBuffer(size_t byteSize)
- {
- checkCommitted();
- return makeRef<Buffer>(Ref<Device>(this), byteSize);
- }
-
- Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize)
- {
- checkCommitted();
- return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize);
- }
-
- Ref<Filter> Device::newFilter(const std::string& type)
- {
- checkCommitted();
-
- if (isVerbose())
- std::cout << "Filter: " << type << std::endl;
-
- Ref<Filter> filter;
-
-// -- GODOT start --
-// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-#if 0
-// -- GODOT end --
- if (type == "RT")
- filter = makeRef<RTFilter>(Ref<Device>(this));
-// -- GODOT start --
-// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-#endif
- if (type == "RTLightmap")
-// -- GODOT end --
- filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
- else
- throw Exception(Error::InvalidArgument, "unknown filter type");
-
- return filter;
- }
-
- void Device::print()
- {
- std::cout << std::endl;
-
- std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl;
- std::cout << " Compiler: " << getCompilerName() << std::endl;
- std::cout << " Build : " << getBuildName() << std::endl;
- std::cout << " Platform: " << getPlatformName() << std::endl;
-
-// -- GODOT start --
-// std::cout << " Tasking :";
-// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
-// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
-// std::cout << std::endl;
-// -- GODOT end --
- std::cout << std::endl;
- }
-
-} // namespace oidn
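The error handling above keeps one pending error per thread and per device, plus a static thread_local fallback for calls made without a device; only the first unreported error is stored, and querying it clears the code while keeping the message buffer alive so the returned pointer stays valid until the next query. A standalone sketch of that pattern using plain thread_local instead of the ThreadLocal<> helper from common/thread.h (function names are illustrative):

    #include <iostream>
    #include <string>
    #include <thread>

    struct ErrorState { int code = 0; std::string message; };
    thread_local ErrorState tlsError;

    void setError(int code, const std::string& message)
    {
      // Keep only the first unqueried error ("sticky" until read).
      if (tlsError.code == 0) { tlsError.code = code; tlsError.message = message; }
    }

    int getError(const char** outMessage)
    {
      const int code = tlsError.code;
      if (outMessage) *outMessage = (code == 0) ? nullptr : tlsError.message.c_str();
      tlsError.code = 0;   // clear the code, keep the message storage alive
      return code;
    }

    int main()
    {
      setError(1, "first error");
      setError(2, "second error is dropped");   // sticky: the first one wins
      std::thread([] {
        const char* msg = nullptr;
        std::cout << "worker thread code: " << getError(&msg) << std::endl;  // 0, no error here
      }).join();
      const char* msg = nullptr;
      std::cout << "main thread code: " << getError(&msg) << " (" << msg << ")" << std::endl;
      return 0;
    }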
diff --git a/thirdparty/oidn/core/device.h b/thirdparty/oidn/core/device.h
deleted file mode 100644
index d9cfd8541a..0000000000
--- a/thirdparty/oidn/core/device.h
+++ /dev/null
@@ -1,102 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-
-namespace oidn {
-
- class Buffer;
- class Filter;
-
- class Device : public RefCount, public Verbose
- {
- private:
- // Thread-safety
- std::mutex mutex;
-
- // Error handling
- struct ErrorState
- {
- Error code = Error::None;
- std::string message;
- };
-
- static thread_local ErrorState globalError;
- ThreadLocal<ErrorState> error;
- ErrorFunction errorFunc = nullptr;
- void* errorUserPtr = nullptr;
-
-// -- GODOT start --
-// // Tasking
-// std::shared_ptr<tbb::task_arena> arena;
-// std::shared_ptr<PinningObserver> observer;
-// std::shared_ptr<ThreadAffinity> affinity;
-// -- GODOT end --
-
- // Parameters
- int numThreads = 0; // autodetect by default
- bool setAffinity = true;
-
- bool dirty = true;
-
- public:
- Device();
- ~Device();
-
- static void setError(Device* device, Error code, const std::string& message);
- static Error getError(Device* device, const char** outMessage);
-
- void setErrorFunction(ErrorFunction func, void* userPtr);
-
- int get1i(const std::string& name);
- void set1i(const std::string& name, int value);
-
- void commit();
-
-// -- GODOT start --
-// template<typename F>
-// void executeTask(F& f)
-// {
-// arena->execute(f);
-// }
-
-// template<typename F>
-// void executeTask(const F& f)
-// {
-// arena->execute(f);
-// }
-// -- GODOT end --
-
- Ref<Buffer> newBuffer(size_t byteSize);
- Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
- Ref<Filter> newFilter(const std::string& type);
-
- __forceinline Device* getDevice() { return this; }
- __forceinline std::mutex& getMutex() { return mutex; }
-
- private:
-// -- GODOT start --
- //bool isCommitted() const { return bool(arena); }
- bool isCommitted() const { return false; }
-// -- GODOT end --
- void checkCommitted();
-
- void print();
- };
-
-} // namespace oidn
diff --git a/thirdparty/oidn/core/filter.cpp b/thirdparty/oidn/core/filter.cpp
deleted file mode 100644
index ec1f10af87..0000000000
--- a/thirdparty/oidn/core/filter.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "filter.h"
-
-namespace oidn {
-
- void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr)
- {
- progressFunc = func;
- progressUserPtr = userPtr;
- }
-
-} // namespace oidn
diff --git a/thirdparty/oidn/core/filter.h b/thirdparty/oidn/core/filter.h
deleted file mode 100644
index 935fa202f4..0000000000
--- a/thirdparty/oidn/core/filter.h
+++ /dev/null
@@ -1,52 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-#include "device.h"
-#include "image.h"
-
-namespace oidn {
-
- class Filter : public RefCount
- {
- protected:
- Ref<Device> device;
-
- ProgressMonitorFunction progressFunc = nullptr;
- void* progressUserPtr = nullptr;
-
- bool dirty = true;
-
- public:
- explicit Filter(const Ref<Device>& device) : device(device) {}
-
- virtual void setImage(const std::string& name, const Image& data) = 0;
- virtual void set1i(const std::string& name, int value) = 0;
- virtual int get1i(const std::string& name) = 0;
- virtual void set1f(const std::string& name, float value) = 0;
- virtual float get1f(const std::string& name) = 0;
-
- void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr);
-
- virtual void commit() = 0;
- virtual void execute() = 0;
-
- Device* getDevice() { return device.get(); }
- };
-
-} // namespace oidn
diff --git a/thirdparty/oidn/core/image.h b/thirdparty/oidn/core/image.h
deleted file mode 100644
index 748f49c4e5..0000000000
--- a/thirdparty/oidn/core/image.h
+++ /dev/null
@@ -1,111 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-#include "buffer.h"
-
-namespace oidn {
-
- struct Image
- {
- static constexpr int maxSize = 65536;
-
- char* ptr; // pointer to the first pixel
- int width; // width in number of pixels
- int height; // height in number of pixels
- size_t bytePixelStride; // pixel stride in number of *bytes*
- size_t rowStride; // row stride in number of *pixel strides*
- Format format; // pixel format
- Ref<Buffer> buffer; // buffer containing the image data
-
- Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {}
-
- Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
- {
- if (ptr == nullptr)
- throw Exception(Error::InvalidArgument, "buffer pointer null");
-
- init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
- }
-
- Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
- {
- init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
-
- if (byteOffset + height * rowStride * bytePixelStride > buffer->size())
- throw Exception(Error::InvalidArgument, "buffer region out of range");
- }
-
- void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride)
- {
- assert(width >= 0);
- assert(height >= 0);
- if (width > maxSize || height > maxSize)
- throw Exception(Error::InvalidArgument, "image size too large");
-
- this->ptr = ptr;
- this->width = width;
- this->height = height;
-
- const size_t pixelSize = getFormatBytes(format);
- if (inBytePixelStride != 0)
- {
- if (inBytePixelStride < pixelSize)
- throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size");
-
- this->bytePixelStride = inBytePixelStride;
- }
- else
- {
- this->bytePixelStride = pixelSize;
- }
-
- if (inByteRowStride != 0)
- {
- if (inByteRowStride < width * this->bytePixelStride)
- throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride");
- if (inByteRowStride % this->bytePixelStride != 0)
- throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride");
-
- this->rowStride = inByteRowStride / this->bytePixelStride;
- }
- else
- {
- this->rowStride = width;
- }
-
- this->format = format;
- }
-
- __forceinline char* get(int y, int x)
- {
- return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
- }
-
- __forceinline const char* get(int y, int x) const
- {
- return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
- }
-
- operator bool() const
- {
- return ptr != nullptr;
- }
- };
-
-} // namespace oidn
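One easy thing to miss in Image above is that rowStride is stored in units of pixel strides rather than bytes, so get(y, x) resolves to ptr + (y * rowStride + x) * bytePixelStride. A worked example where an interleaved RGBA float buffer is viewed as a 3-channel image through a 16-byte pixel stride (sizes are illustrative only):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main()
    {
      const int width = 4, height = 3;
      std::vector<float> rgba(size_t(width) * height * 4);
      for (size_t i = 0; i < rgba.size(); ++i) rgba[i] = float(i);

      const char* ptr = reinterpret_cast<const char*>(rgba.data());
      const size_t bytePixelStride = 4 * sizeof(float);   // one RGBA pixel
      const size_t rowStride = width;                     // in pixel strides, not bytes

      const int y = 2, x = 1;
      const char* p = ptr + (size_t(y) * rowStride + size_t(x)) * bytePixelStride;
      const float* texel = reinterpret_cast<const float*>(p);

      // Pixel (2,1) starts at float index (2*4 + 1) * 4 = 36; the alpha channel is skipped.
      assert(texel[0] == 36.f && texel[1] == 37.f && texel[2] == 38.f);
      return 0;
    }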
diff --git a/thirdparty/oidn/core/input_reorder.h b/thirdparty/oidn/core/input_reorder.h
deleted file mode 100644
index 966856afe9..0000000000
--- a/thirdparty/oidn/core/input_reorder.h
+++ /dev/null
@@ -1,232 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "node.h"
-#include "image.h"
-
-namespace oidn {
-
- // Input reorder node
- template<int K, class TransferFunction>
- class InputReorderNode : public Node
- {
- private:
- // Source
- Image color;
- Image albedo;
- Image normal;
-
- // Destination
- std::shared_ptr<memory> dst;
- float* dstPtr;
- int C2;
- int H2;
- int W2;
-
- // Tile
- int h1Begin;
- int w1Begin;
- int h2Begin;
- int w2Begin;
- int H;
- int W;
-
- std::shared_ptr<TransferFunction> transferFunc;
-
- public:
- InputReorderNode(const Image& color,
- const Image& albedo,
- const Image& normal,
- const std::shared_ptr<memory>& dst,
- const std::shared_ptr<TransferFunction>& transferFunc)
- : color(color), albedo(albedo), normal(normal),
- dst(dst),
- h1Begin(0), w1Begin(0),
- H(color.height), W(color.width),
- transferFunc(transferFunc)
- {
- const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
- assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
- assert(dstDesc.ndims == 4);
- assert(dstDesc.data_type == memory::data_type::f32);
- assert(dstDesc.dims[0] == 1);
- //assert(dstDesc.dims[1] >= getPadded<K>(C1));
-
- dstPtr = (float*)dst->get_data_handle();
- C2 = dstDesc.dims[1];
- H2 = dstDesc.dims[2];
- W2 = dstDesc.dims[3];
- }
-
- void setTile(int h1, int w1, int h2, int w2, int H, int W) override
- {
- h1Begin = h1;
- w1Begin = w1;
- h2Begin = h2;
- w2Begin = w2;
- this->H = H;
- this->W = W;
- }
-
- void execute(stream& sm) override
- {
- assert(H + h1Begin <= color.height);
- assert(W + w1Begin <= color.width);
- assert(H + h2Begin <= H2);
- assert(W + w2Begin <= W2);
-
- parallel_nd(H2, [&](int h2)
- {
- const int h = h2 - h2Begin;
-
- if (h >= 0 && h < H)
- {
- const int h1 = h + h1Begin;
-
- // Zero pad
- for (int w2 = 0; w2 < w2Begin; ++w2)
- {
- int c = 0;
- while (c < C2)
- store(h2, w2, c, 0.f);
- }
-
- // Reorder
- for (int w = 0; w < W; ++w)
- {
- const int w1 = w + w1Begin;
- const int w2 = w + w2Begin;
-
- int c = 0;
- storeColor(h2, w2, c, (float*)color.get(h1, w1));
- if (albedo)
- storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
- if (normal)
- storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
- while (c < C2)
- store(h2, w2, c, 0.f);
- }
-
- // Zero pad
- for (int w2 = W + w2Begin; w2 < W2; ++w2)
- {
- int c = 0;
- while (c < C2)
- store(h2, w2, c, 0.f);
- }
- }
- else
- {
- // Zero pad
- for (int w2 = 0; w2 < W2; ++w2)
- {
- int c = 0;
- while (c < C2)
- store(h2, w2, c, 0.f);
- }
- }
- });
- }
-
- std::shared_ptr<memory> getDst() const override { return dst; }
-
- private:
- // Stores a single value
- __forceinline void store(int h, int w, int& c, float value)
- {
- // Destination is in nChwKc format
- float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
- *dst_c = value;
- c++;
- }
-
- // Stores a color
- __forceinline void storeColor(int h, int w, int& c, const float* values)
- {
- #pragma unroll
- for (int i = 0; i < 3; ++i)
- {
- // Load the value
- float x = values[i];
-
- // Sanitize the value
- x = maxSafe(x, 0.f);
-
- // Apply the transfer function
- x = transferFunc->forward(x);
-
- // Store the value
- store(h, w, c, x);
- }
- }
-
- // Stores an albedo
- __forceinline void storeAlbedo(int h, int w, int& c, const float* values)
- {
- #pragma unroll
- for (int i = 0; i < 3; ++i)
- {
- // Load the value
- float x = values[i];
-
- // Sanitize the value
- x = clampSafe(x, 0.f, 1.f);
-
- // Store the value
- store(h, w, c, x);
- }
- }
-
- // Stores a normal
- __forceinline void storeNormal(int h, int w, int& c, const float* values)
- {
- // Load the normal
- float x = values[0];
- float y = values[1];
- float z = values[2];
-
- // Compute the length of the normal
- const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
-
- // Normalize the normal and transform it to [0..1]
- if (isfinite(lengthSqr))
- {
- const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
-
- const float scale = invLength * 0.5f;
- const float offset = 0.5f;
-
- x = x * scale + offset;
- y = y * scale + offset;
- z = z * scale + offset;
- }
- else
- {
- x = 0.f;
- y = 0.f;
- z = 0.f;
- }
-
- // Store the normal
- store(h, w, c, x);
- store(h, w, c, y);
- store(h, w, c, z);
- }
- };
-
-} // namespace oidn
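InputReorderNode::store() above writes directly in the blocked nChwKc layout: channels are grouped into blocks of K, the block index varies slowest after the batch dimension, then h, then w, and the K channels inside a block are interleaved per pixel. The small check below walks every (c, h, w) through the same index formula and confirms it touches each element of the buffer exactly once (illustrative sizes):

    #include <cassert>
    #include <vector>

    int main()
    {
      const int K = 8, C = 16, H = 3, W = 5;   // C is already a multiple of K
      std::vector<int> hits(size_t(C) * H * W, 0);

      for (int c = 0; c < C; ++c)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
          {
            // Same layout as: dstPtr + H2*W2*K*(c/K) + h*W2*K + w*K + (c%K)
            const size_t idx = size_t(H) * W * K * (c / K)
                             + size_t(h) * W * K + size_t(w) * K + (c % K);
            hits[idx]++;
          }

      for (int v : hits) assert(v == 1);       // a bijection onto [0, C*H*W)
      return 0;
    }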
diff --git a/thirdparty/oidn/core/math.h b/thirdparty/oidn/core/math.h
deleted file mode 100644
index a844ef0d1d..0000000000
--- a/thirdparty/oidn/core/math.h
+++ /dev/null
@@ -1,78 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common/platform.h"
-
-namespace oidn {
-
- constexpr float minVectorLength = 1e-10f;
- constexpr float minVectorLengthSqr = minVectorLength * minVectorLength;
-
- using std::log;
- using std::log2;
- using std::exp;
- using std::exp2;
- using std::pow;
- using std::isfinite;
- using std::isnan;
-
- __forceinline float sqr(float x)
- {
- return x * x;
- }
-
- __forceinline float rcp(float x)
- {
- __m128 r = _mm_rcp_ss(_mm_set_ss(x));
- return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x))));
- }
-
- __forceinline float rsqrt(float x)
- {
- __m128 r = _mm_rsqrt_ss(_mm_set_ss(x));
- return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r),
- _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))));
- }
-
- __forceinline float maxSafe(float value, float minValue)
- {
- return isfinite(value) ? max(value, minValue) : minValue;
- }
-
- __forceinline float clampSafe(float value, float minValue, float maxValue)
- {
- return isfinite(value) ? clamp(value, minValue, maxValue) : minValue;
- }
-
- // Returns ceil(a / b) for non-negative integers
- template<class Int>
- __forceinline constexpr Int ceilDiv(Int a, Int b)
- {
- //assert(a >= 0);
- //assert(b > 0);
- return (a + b - 1) / b;
- }
-
- // Returns a rounded up to multiple of b
- template<class Int>
- __forceinline constexpr Int roundUp(Int a, Int b)
- {
- return ceilDiv(a, b) * b;
- }
-
-} // namespace oidn
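rcp() and rsqrt() above refine the low-precision SSE estimates (_mm_rcp_ss and _mm_rsqrt_ss, roughly 12 bits) with one Newton-Raphson step: r' = r*(2 - x*r) for the reciprocal and r' = r*(1.5 - 0.5*x*r*r) for the reciprocal square root. The scalar sketch below seeds the iteration with a deliberately perturbed value to show the quadratic error reduction; the real code takes its seed from the SSE approximation instructions:

    #include <cassert>
    #include <cmath>
    #include <cstdio>

    int main()
    {
      const double x = 3.7;

      double r = (1.0 / x) * 1.01;                 // ~1% error seed
      r = r * (2.0 - x * r);                       // one NR step for 1/x
      std::printf("rcp relative error:   %.2e\n", std::fabs(r - 1.0 / x) * x);

      double s = (1.0 / std::sqrt(x)) * 1.01;      // ~1% error seed
      s = s * (1.5 - 0.5 * x * s * s);             // one NR step for 1/sqrt(x)
      std::printf("rsqrt relative error: %.2e\n", std::fabs(s - 1.0 / std::sqrt(x)) * std::sqrt(x));

      // One step turns a ~1e-2 error into roughly its square (~1e-4).
      assert(std::fabs(r * x - 1.0) < 1e-3);
      assert(std::fabs(s * s * x - 1.0) < 1e-3);
      return 0;
    }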
diff --git a/thirdparty/oidn/core/network.cpp b/thirdparty/oidn/core/network.cpp
deleted file mode 100644
index ed8328c954..0000000000
--- a/thirdparty/oidn/core/network.cpp
+++ /dev/null
@@ -1,436 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "upsample.h"
-#include "weights_reorder.h"
-#include "network.h"
-// -- GODOT start --
-#include <cstring>
-// -- GODOT end --
-
-namespace oidn {
-
- template<int K>
- Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap)
- : device(device),
- eng(engine::cpu, 0),
- sm(eng),
- weightMap(weightMap)
- {
- }
-
- template<int K>
- void Network<K>::execute(const Progress& progress, int taskIndex)
- {
- if (progress.func)
- {
- const double value = double(taskIndex) / double(progress.taskCount);
- if (!progress.func(progress.userPtr, value))
- throw Exception(Error::Cancelled, "execution was cancelled");
- }
-
- for (size_t i = 0; i < nodes.size(); ++i)
- {
- nodes[i]->execute(sm);
-
- if (progress.func)
- {
- const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount);
- if (!progress.func(progress.userPtr, value))
- throw Exception(Error::Cancelled, "execution was cancelled");
- }
- }
- }
-
- template<int K>
- std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims,
- memory::format_tag format,
- void* data)
- {
- if (format == memory::format_tag::any)
- {
- if (dims.size() == 4)
- format = BlockedFormat<K>::nChwKc;
- else if (dims.size() == 1)
- format = memory::format_tag::x;
- else
- assert(0);
- }
- memory::desc desc(dims, memory::data_type::f32, format);
- if (data == nullptr)
- {
- const size_t bytes = getTensorSize(dims) * sizeof(float);
- if (format == BlockedFormat<K>::nChwKc)
- activationAllocBytes += bytes;
- totalAllocBytes += bytes;
-
- return std::make_shared<memory>(desc, eng);
- }
- else
- {
- return std::make_shared<memory>(desc, eng, data);
- }
- }
-
- template<int K>
- std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
- const std::shared_ptr<memory>& src,
- size_t srcOffset,
- memory::format_tag format)
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
- MAYBE_UNUSED(srcDesc);
- assert(srcDesc.data_type == memory::data_type::f32);
- assert(getTensorSize(src) >= srcOffset + getTensorSize(dims));
-
- if (format == memory::format_tag::any)
- {
- if (dims.size() == 4)
- format = BlockedFormat<K>::nChwKc;
- else if (dims.size() == 1)
- format = memory::format_tag::x;
- else
- assert(0);
- }
- memory::desc desc(dims, memory::data_type::f32, format);
- float* srcPtr = (float*)src->get_data_handle() + srcOffset;
- return std::make_shared<memory>(desc, eng, srcPtr);
- }
-
- template<int K>
- std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
- const std::shared_ptr<memory>& src,
- const memory::dims& srcOffset)
- {
- return castTensor(dims, src, getTensorSize(srcOffset));
- }
-
- template<int K>
- void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst)
- {
- assert(getTensorType(dst) == memory::data_type::f32);
- memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float));
- }
-
- template<int K>
- memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment)
- {
- memory::dims dstDims = srcDims;
- dstDims[1] = getPadded<K>(srcDims[1]); // round up C
- dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H
- dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W
- return dstDims;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color,
- const Image& albedo,
- const Image& normal,
- const std::shared_ptr<TransferFunction>& transferFunc,
- int alignment,
- const std::shared_ptr<memory>& userDst)
- {
- assert(color);
- int inputC = 3;
- if (albedo) inputC += 3;
- if (normal) inputC += 3;
-
- memory::dims srcDims = {1, inputC, color.height, color.width};
- memory::dims dstDims = getInputReorderDims(srcDims, alignment);
-
- // Allocate padded memory
- auto dst = userDst;
- if (!dst)
- dst = allocTensor(dstDims);
-
- // Push node
- std::shared_ptr<Node> node;
-
- if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
- node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf);
- else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
- node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf);
- else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
- node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf);
- else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
- node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf);
- else
- assert(0);
-
- nodes.push_back(node);
- return node;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src,
- const std::shared_ptr<TransferFunction>& transferFunc,
- const Image& output)
- {
- memory::dims srcDims = getTensorDims(src);
- assert(srcDims[1] == K);
-
- // Push node
- std::shared_ptr<Node> node;
-
- if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
- node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf);
- else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
- node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf);
- else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
- node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf);
- else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
- node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf);
- else
- assert(0);
-
- nodes.push_back(node);
- return node;
- }
-
- template<int K>
- memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims)
- {
- auto b = weightMap[name + "/b"];
- memory::dims dstDims = srcDims;
- dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC)
- return dstDims;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addConv(const std::string& name,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst,
- bool relu)
- {
- const memory::dims strides = {1, 1};
- const memory::dims padding = {1, 1};
-
- memory::dims srcDims = getTensorDims(src);
-
- // Get the weights
- const auto& W = weightMap[name + "/W"];
- if (W.ndims() != 4 || W.format != "oihw")
- throw Exception(Error::InvalidOperation, "invalid convolution weights");
- memory::dims weightsDims = W.dims;
- auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data);
-
- // Pad the weights
- memory::dims weightsPadDims = weightsDims;
- weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC
- weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC
- assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC]
- auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw);
- WeightsReorderNode<K>(userWeights, weightsPad).execute(sm);
-
- // Get the biases
- const auto& b = weightMap[name + "/b"];
- if (b.ndims() != 1)
- throw Exception(Error::InvalidOperation, "invalid convolution biases");
- memory::dims biasDims = b.dims;
-
- // Copy/pad the biases
- memory::dims biasPadDims = {getPadded<K>(biasDims[0])};
- auto bias = allocTensor(biasPadDims);
- if (biasDims[0] != biasPadDims[0])
- memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float));
- memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float));
-
- // Allocate memory for destination
- memory::dims dstDims = srcDims;
- dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC]
-
- std::shared_ptr<memory> dst;
- if (!userDst)
- dst = allocTensor(dstDims);
- else if (getTensorDims(userDst) == dstDims)
- dst = userDst;
- else
- dst = castTensor(dstDims, userDst);
-
- // Create a convolution
- // Let the convolution primitive choose the weights format
- auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any);
-
- auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct;
- auto convDesc = convolution_forward::desc(
- prop_kind::forward_inference, convAlgo,
- src->get_desc(),
- weightsDesc,
- bias->get_desc(),
- dst->get_desc(),
- strides, padding, padding, padding_kind::zero);
-
- // Incorporate relu
- mkldnn::primitive_attr convAttr;
- if (relu)
- {
- mkldnn::post_ops ops;
- ops.append_eltwise(
- 1.f, // scale factor, not used
- algorithm::eltwise_relu,
- 0.f, // max with
- 0.f // unused
- );
- convAttr.set_post_ops(ops);
- }
- convAttr.set_scratchpad_mode(scratchpad_mode_user);
-
- auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng);
-
- // Reorder the weights to the final format, if necessary
- auto weights = weightsPad;
- if (convPrimDesc.weights_desc() != weightsPad->get_desc())
- {
- weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng);
- ReorderNode(weightsPad, weights).execute(sm);
- }
-
- // Create convolution node and add it to the net
- auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst);
- nodes.push_back(node);
- return node;
- }
-
- template<int K>
- memory::dims Network<K>::getPoolDims(const memory::dims& srcDims)
- {
- memory::dims dstDims = srcDims;
- dstDims[2] /= 2; // H/2
- dstDims[3] /= 2; // W/2
- return dstDims;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst)
- {
- const memory::dims kernel = {2, 2};
- const memory::dims strides = {2, 2};
- const memory::dims padding = {0, 0};
-
- memory::dims srcDims = getTensorDims(src);
- memory::dims dstDims = getPoolDims(srcDims);
-
- std::shared_ptr<memory> dst;
- if (!userDst)
- dst = allocTensor(dstDims);
- else if (getTensorDims(userDst) == dstDims)
- dst = userDst;
- else
- dst = castTensor(dstDims, userDst);
-
- auto poolDesc = pooling_forward::desc(
- prop_kind::forward_inference, pooling_max,
- src->get_desc(),
- dst->get_desc(),
- strides, kernel, padding, padding, padding_kind::zero);
-
- mkldnn::primitive_attr poolAttr;
- poolAttr.set_scratchpad_mode(scratchpad_mode_user);
-
- auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng);
-
- auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst);
- nodes.push_back(node);
- return node;
- }
-
- template<int K>
- memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims)
- {
- memory::dims dstDims = srcDims;
- dstDims[2] *= 2; // H*2
- dstDims[3] *= 2; // W*2
- return dstDims;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst)
- {
- memory::dims srcDims = getTensorDims(src);
- memory::dims dstDims = getUpsampleDims(srcDims);
-
- std::shared_ptr<memory> dst;
- if (!userDst)
- dst = allocTensor(dstDims);
- else if (getTensorDims(userDst) == dstDims)
- dst = userDst;
- else
- dst = castTensor(dstDims, userDst);
-
- // Create upsampling node and add it to net
- auto node = std::make_shared<UpsampleNode<K>>(src, dst);
- nodes.push_back(node);
- return node;
- }
-
- template<int K>
- memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims)
- {
- assert(src1Dims[0] == src2Dims[0]); // N
- assert(src1Dims[2] == src2Dims[2]); // H
- assert(src1Dims[3] == src2Dims[3]); // W
-
- memory::dims dstDims = src1Dims;
- dstDims[1] += src2Dims[1]; // C
- return dstDims;
- }
-
- template<int K>
- std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color,
- const std::shared_ptr<HDRTransferFunction>& transferFunc)
- {
- auto node = std::make_shared<AutoexposureNode>(color, transferFunc);
- nodes.push_back(node);
- return node;
- }
-
- template <int K>
- void Network<K>::finalize()
- {
- // Compute the size of the scratchpad
- size_t scratchpadSize = 0;
- for (const auto& node : nodes)
- scratchpadSize = max(scratchpadSize, node->getScratchpadSize());
-
- // Allocate the scratchpad
- memory::dims scratchpadDims = { memory::dim(scratchpadSize) };
- memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x);
- auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng);
- activationAllocBytes += scratchpadSize;
- totalAllocBytes += scratchpadSize;
-
- // Set the scratchpad for the nodes
- for (auto& node : nodes)
- node->setScratchpad(scratchpad);
-
- // Free the weights
- weightMap.clear();
-
- // Print statistics
- if (device->isVerbose(2))
- {
- std::cout << "Activation bytes: " << activationAllocBytes << std::endl;
- std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl;
- std::cout << "Total bytes : " << totalAllocBytes << std::endl;
- }
- }
-
- template class Network<8>;
- template class Network<16>;
-
-} // namespace oidn
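Network::execute() above reports progress as (taskIndex + nodesDone/nodeCount) / taskCount; the filter passes one task per tile, so the value climbs monotonically from 0 to 1 across the whole image, and a monitor that returns false cancels the run (the code then throws Error::Cancelled). A tiny standalone check of that arithmetic with assumed tile and node counts:

    #include <cassert>
    #include <cstdio>

    int main()
    {
      const int taskCount = 4;       // e.g. a 2x2 tile grid
      const int nodeCount = 30;      // nodes in the network
      double last = -1.0;

      for (int task = 0; task < taskCount; ++task)
        for (int i = 0; i < nodeCount; ++i)
        {
          const double value = (double(task) + double(i + 1) / double(nodeCount)) / double(taskCount);
          assert(value > last && value <= 1.0);   // strictly increasing, ends at 1
          last = value;
        }

      std::printf("final progress = %.3f\n", last);
      return 0;
    }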
diff --git a/thirdparty/oidn/core/network.h b/thirdparty/oidn/core/network.h
deleted file mode 100644
index 7a696fd355..0000000000
--- a/thirdparty/oidn/core/network.h
+++ /dev/null
@@ -1,112 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "common/tensor.h"
-#include "image.h"
-#include "node.h"
-#include "input_reorder.h"
-#include "output_reorder.h"
-#include "transfer_function.h"
-
-#pragma once
-
-namespace oidn {
-
- // Progress state
- struct Progress
- {
- ProgressMonitorFunction func;
- void* userPtr;
- int taskCount;
- };
-
- class Executable
- {
- public:
- virtual ~Executable() {}
- virtual void execute(const Progress& progress, int taskIndex) = 0;
- };
-
- template<int K>
- class Network : public Executable
- {
- public:
- Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
-
- void execute(const Progress& progress, int taskIndex) override;
-
- std::shared_ptr<memory> allocTensor(const memory::dims& dims,
- memory::format_tag format = memory::format_tag::any,
- void* data = nullptr);
-
- std::shared_ptr<memory> castTensor(const memory::dims& dims,
- const std::shared_ptr<memory>& src,
- size_t srcOffset = 0,
- memory::format_tag format = memory::format_tag::any);
-
- std::shared_ptr<memory> castTensor(const memory::dims& dims,
- const std::shared_ptr<memory>& src,
- const memory::dims& srcOffset);
-
- void zeroTensor(const std::shared_ptr<memory>& dst);
-
- memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
-
- std::shared_ptr<Node> addInputReorder(const Image& color,
- const Image& albedo,
- const Image& normal,
- const std::shared_ptr<TransferFunction>& transferFunc,
- int alignment,
- const std::shared_ptr<memory>& userDst = nullptr);
-
- std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
- const std::shared_ptr<TransferFunction>& transferFunc,
- const Image& output);
-
- memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
- std::shared_ptr<Node> addConv(const std::string& name,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst = nullptr,
- bool relu = true);
-
- memory::dims getPoolDims(const memory::dims& srcDims);
- std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst = nullptr);
-
- memory::dims getUpsampleDims(const memory::dims& srcDims);
- std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& userDst = nullptr);
-
- memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
-
- std::shared_ptr<Node> addAutoexposure(const Image& color,
- const std::shared_ptr<HDRTransferFunction>& transferFunc);
-
- void finalize();
-
- private:
- Ref<Device> device;
- engine eng;
- stream sm;
- std::vector<std::shared_ptr<Node>> nodes;
- std::map<std::string, Tensor> weightMap;
-
- // Memory allocation statistics
- size_t activationAllocBytes = 0; // number of allocated activation bytes
- size_t totalAllocBytes = 0; // total number of allocated bytes
- };
-
-} // namespace oidn
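
The header above exposes Network<K> as a small graph builder: each add* call appends a node and returns it, and that node's getDst() tensor feeds the next layer. The snippet below is a hypothetical illustration of this flow only; the weight names ("conv1", "conv2") and the alignment value are made up for the example.

    // Illustrative only: assembles a few layers through the declared builder API.
    std::shared_ptr<oidn::Node> buildHead(oidn::Network<16>& net,
                                          const oidn::Image& color,
                                          const oidn::Image& albedo,
                                          const oidn::Image& normal,
                                          const std::shared_ptr<oidn::TransferFunction>& tf)
    {
      auto input = net.addInputReorder(color, albedo, normal, tf, /*alignment=*/32);
      auto conv1 = net.addConv("conv1", input->getDst());  // conv + ReLU
      auto pool1 = net.addPool(conv1->getDst());           // 2x2 pooling
      auto conv2 = net.addConv("conv2", pool1->getDst());
      return conv2;
      // ... once all layers are added, net.finalize() allocates one scratchpad
      // sized for the largest node and frees the raw weight map.
    }
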
diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h
deleted file mode 100644
index b9ffe906df..0000000000
--- a/thirdparty/oidn/core/node.h
+++ /dev/null
@@ -1,142 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-#include <vector>
-
-namespace oidn {
-
- class Node
- {
- public:
- virtual ~Node() = default;
-
- virtual void execute(stream& sm) = 0;
-
- virtual std::shared_ptr<memory> getDst() const { return nullptr; }
-
- virtual size_t getScratchpadSize() const { return 0; }
- virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
-
- virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
- {
- assert(0); // not supported
- }
- };
-
- // Node wrapping an MKL-DNN primitive
- class MklNode : public Node
- {
- private:
- primitive prim;
- std::unordered_map<int, memory> args;
- std::shared_ptr<memory> scratchpad;
-
- public:
- MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
- : prim(prim),
- args(args)
- {}
-
- size_t getScratchpadSize() const override
- {
- const auto primDesc = prim.get_primitive_desc();
- const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
- if (scratchpadDesc == nullptr)
- return 0;
- return mkldnn_memory_desc_get_size(scratchpadDesc);
- }
-
- void setScratchpad(const std::shared_ptr<memory>& mem) override
- {
- scratchpad = mem;
- args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
- }
-
- void execute(stream& sm) override
- {
- prim.execute(sm, args);
- }
- };
-
- // Convolution node
- class ConvNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> weights;
- std::shared_ptr<memory> bias;
- std::shared_ptr<memory> dst;
-
- public:
- ConvNode(const convolution_forward::primitive_desc& desc,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& weights,
- const std::shared_ptr<memory>& bias,
- const std::shared_ptr<memory>& dst)
- : MklNode(convolution_forward(desc),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_WEIGHTS, *weights },
- { MKLDNN_ARG_BIAS, *bias },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), weights(weights), bias(bias), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
- // Pooling node
- class PoolNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- PoolNode(const pooling_forward::primitive_desc& desc,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : MklNode(pooling_forward(desc),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
- // Reorder node
- class ReorderNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- ReorderNode(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : MklNode(reorder(reorder::primitive_desc(*src, *dst)),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
-} // namespace oidn
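
Node is deliberately small: a subclass only has to implement execute(), while scratchpad and tiling support are opt-in. As a sketch (not a node that exists in the library), a trivial copy node under this interface could look like the following; it assumes the mkldnn memory::desc::get_size() and memory::get_data_handle() accessors.

    #include <cassert>
    #include <cstring>

    namespace oidn {

      // Copies one tensor into another of the same byte size.
      class CopyNode : public Node
      {
      private:
        std::shared_ptr<memory> src;
        std::shared_ptr<memory> dst;

      public:
        CopyNode(const std::shared_ptr<memory>& src, const std::shared_ptr<memory>& dst)
          : src(src), dst(dst)
        {
          assert(src->get_desc().get_size() == dst->get_desc().get_size());
        }

        void execute(stream& sm) override
        {
          std::memcpy(dst->get_data_handle(), src->get_data_handle(),
                      src->get_desc().get_size());
        }

        std::shared_ptr<memory> getDst() const override { return dst; }
      };

    } // namespace oidn
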
diff --git a/thirdparty/oidn/core/output_reorder.h b/thirdparty/oidn/core/output_reorder.h
deleted file mode 100644
index 7918d48e15..0000000000
--- a/thirdparty/oidn/core/output_reorder.h
+++ /dev/null
@@ -1,126 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "node.h"
-#include "image.h"
-
-namespace oidn {
-
- // Output reorder node
- template<int K, class TransferFunction>
- class OutputReorderNode : public Node
- {
- private:
- // Source
- std::shared_ptr<memory> src;
- const float* srcPtr;
- int H1;
- int W1;
-
- // Destination
- Image output;
-
- // Tile
- int h1Begin;
- int w1Begin;
- int h2Begin;
- int w2Begin;
- int H;
- int W;
-
- std::shared_ptr<TransferFunction> transferFunc;
-
- public:
- OutputReorderNode(const std::shared_ptr<memory>& src,
- const Image& output,
- const std::shared_ptr<TransferFunction>& transferFunc)
- : src(src),
- output(output),
- h1Begin(0), w1Begin(0),
- h2Begin(0), w2Begin(0),
- H(output.height), W(output.width),
- transferFunc(transferFunc)
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
- MAYBE_UNUSED(srcDesc);
- assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
- assert(srcDesc.ndims == 4);
- assert(srcDesc.data_type == memory::data_type::f32);
- assert(srcDesc.dims[0] == 1);
- // We assume output data is <= K OC
- assert(srcDesc.dims[1] == K);
-
- srcPtr = (float*)src->get_data_handle();
- H1 = srcDesc.dims[2];
- W1 = srcDesc.dims[3];
- }
-
- void setTile(int h1, int w1, int h2, int w2, int H, int W) override
- {
- h1Begin = h1;
- w1Begin = w1;
- h2Begin = h2;
- w2Begin = w2;
- this->H = H;
- this->W = W;
- }
-
- void execute(stream& sm) override
- {
- assert(h1Begin + H <= H1);
- assert(w1Begin + W <= W1);
- assert(h2Begin + H <= output.height);
- assert(w2Begin + W <= output.width);
-
- const int C1 = K;
-
- parallel_nd(H, [&](int h)
- {
- const int h1 = h + h1Begin;
- const int h2 = h + h2Begin;
-
- for (int w = 0; w < W; ++w)
- {
- const int w1 = w + w1Begin;
- const int w2 = w + w2Begin;
- float* dstPtr_C = (float*)output.get(h2, w2);
-
-          // Source is in nChwKc format. Here the outer channel-block count (C/K) is 1, so the layout is effectively nhwc
- const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
-
- #pragma unroll
- for (int i = 0; i < 3; ++i)
- {
- // Load the value
- float x = srcPtr_C[i];
-
- // The CNN output may contain negative values or even NaNs, so it must be sanitized
- x = maxSafe(x, 0.f);
-
- // Apply the inverse transfer function
- x = transferFunc->inverse(x);
-
- // Sanitize and store the final value
- dstPtr_C[i] = max(x, 0.f);
- }
- }
- });
- }
- };
-
-} // namespace oidn
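
Since the node above asserts a single channel block (dims[1] == K), the blocked nChwKc source degenerates to an nhwc walk, and the per-channel work is: sanitize, apply the inverse transfer function, sanitize again. A standalone sketch of those two pieces (maxSafe is assumed to map NaNs to the lower bound, as the comment in the code implies):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Flat index of channel c at pixel (h, w) in an nChwKc tensor with a single
    // channel block; W is the tensor width and K the block size.
    inline std::size_t blockedIndex(int h, int w, int c, int W, int K)
    {
      return (std::size_t(h) * W + w) * K + c;
    }

    // Per-channel post-processing of the CNN output: clamp NaNs and negatives,
    // undo the transfer function, then clamp the final value as well.
    template <class TransferFunc>
    inline float postprocess(float x, const TransferFunc& tf)
    {
      x = std::isnan(x) ? 0.f : std::max(x, 0.f);
      return std::max(tf.inverse(x), 0.f);
    }
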
diff --git a/thirdparty/oidn/core/transfer_function.cpp b/thirdparty/oidn/core/transfer_function.cpp
deleted file mode 100644
index ce5deca56b..0000000000
--- a/thirdparty/oidn/core/transfer_function.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#include "transfer_function.h"
-
-namespace oidn {
-
- const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f);
- const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale);
-
- float AutoexposureNode::autoexposure(const Image& color)
- {
- assert(color.format == Format::Float3);
-
- constexpr float key = 0.18f;
- constexpr float eps = 1e-8f;
- constexpr int K = 16; // downsampling amount
-
- // Downsample the image to minimize sensitivity to noise
- const int H = color.height; // original height
- const int W = color.width; // original width
- const int HK = (H + K/2) / K; // downsampled height
- const int WK = (W + K/2) / K; // downsampled width
-
- // Compute the average log luminance of the downsampled image
- using Sum = std::pair<float, int>;
-
- // -- GODOT start --
- // Sum sum =
- // tbb::parallel_reduce(
- // tbb::blocked_range2d<int>(0, HK, 0, WK),
- // Sum(0.f, 0),
- // [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
- // {
- // // Iterate over blocks
- // for (int i = r.rows().begin(); i != r.rows().end(); ++i)
- // {
- // for (int j = r.cols().begin(); j != r.cols().end(); ++j)
- // {
-
- Sum sum = Sum(0.0f, 0);
-
- for (int i = 0; i != HK; ++i)
- {
- for (int j = 0; j != WK; ++j)
- {
- // Compute the average luminance in the current block
- const int beginH = int(ptrdiff_t(i) * H / HK);
- const int beginW = int(ptrdiff_t(j) * W / WK);
- const int endH = int(ptrdiff_t(i+1) * H / HK);
- const int endW = int(ptrdiff_t(j+1) * W / WK);
-
- float L = 0.f;
-
- for (int h = beginH; h < endH; ++h)
- {
- for (int w = beginW; w < endW; ++w)
- {
- const float* rgb = (const float*)color.get(h, w);
-
- const float r = maxSafe(rgb[0], 0.f);
- const float g = maxSafe(rgb[1], 0.f);
- const float b = maxSafe(rgb[2], 0.f);
-
- L += luminance(r, g, b);
- }
- }
-
- L /= (endH - beginH) * (endW - beginW);
-
- // Accumulate the log luminance
- if (L > eps)
- {
- sum.first += log2(L);
- sum.second++;
- }
- }
- }
-
- // return sum;
- // },
- // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
- // tbb::static_partitioner()
- // );
- // -- GODOT end --
-
- return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
- }
-
-} // namespace oidn
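
Stripped of the image traversal, the exposure rule above is: average the log2 of the per-block luminances (ignoring near-black blocks) and scale so that this geometric mean lands on the 18% grey key. A self-contained version of just that reduction:

    #include <cmath>
    #include <vector>

    float exposureFromBlockLuminances(const std::vector<float>& blockLuminance)
    {
      constexpr float key = 0.18f;
      constexpr float eps = 1e-8f;

      double sumLog = 0.0;
      int count = 0;
      for (float L : blockLuminance)
        if (L > eps) { sumLog += std::log2(L); ++count; }

      // key / 2^(mean log2 luminance); 1.0 if the image is essentially black
      return (count > 0) ? key / std::exp2(float(sumLog / count)) : 1.f;
    }

For a uniformly lit image with luminance 0.5 everywhere this returns 0.18 / 0.5 = 0.36, i.e. the scene is scaled towards middle grey before denoising.
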
diff --git a/thirdparty/oidn/core/transfer_function.h b/thirdparty/oidn/core/transfer_function.h
deleted file mode 100644
index 35f2833092..0000000000
--- a/thirdparty/oidn/core/transfer_function.h
+++ /dev/null
@@ -1,201 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "image.h"
-#include "node.h"
-
-namespace oidn {
-
- __forceinline float luminance(float r, float g, float b)
- {
- return 0.212671f * r + 0.715160f * g + 0.072169f * b;
- }
-
- // Color transfer function base class
- class TransferFunction
- {
- public:
- virtual ~TransferFunction() = default;
-
- virtual float forward(float y) const = 0;
- virtual float inverse(float x) const = 0;
- };
-
- // HDR transfer function base class
- class HDRTransferFunction : public TransferFunction
- {
- protected:
- static constexpr float yMax = 65504.f;
-
- float exposure;
- float rcpExposure;
-
- public:
- HDRTransferFunction(float exposure = 1.f)
- {
- setExposure(exposure);
- }
-
- void setExposure(float exposure)
- {
- this->exposure = exposure;
- this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f;
- }
- };
-
- // Linear transfer function (LDR)
- class LinearTransferFunction : public TransferFunction
- {
- public:
- __forceinline float forward(float y) const override
- {
- return min(y, 1.f);
- }
-
- __forceinline float inverse(float x) const override
- {
- return min(x, 1.f);
- }
- };
-
- // 2.2 gamma transfer function (LDR)
- class GammaTransferFunction : public TransferFunction
- {
- public:
- __forceinline float forward(float y) const override
- {
- return min(pow(y, 1.f/2.2f), 1.f);
- }
-
- __forceinline float inverse(float x) const override
- {
- return min(pow(x, 2.2f), 1.f);
- }
- };
-
- // Logarithmic transfer function (HDR)
- // Compresses [0..65504] to [0..1]
- class LogTransferFunction : public HDRTransferFunction
- {
- private:
- static const float xScale;
-
- public:
- LogTransferFunction(float exposure = 1.f)
- : HDRTransferFunction(exposure)
- {
- }
-
- __forceinline float forward(float y) const override
- {
- return log(y * exposure + 1.f) * xScale;
- }
-
- __forceinline float inverse(float x) const override
- {
- return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure;
- }
- };
-
- // PQX transfer function (HDR)
- // Compresses [0..65504] to [0..1]
- class PQXTransferFunction : public HDRTransferFunction
- {
- private:
- static constexpr float m1 = 2610.f / 4096.f / 4.f;
- static constexpr float m2 = 2523.f / 4096.f * 128.f;
- static constexpr float c1 = 3424.f / 4096.f;
- static constexpr float c2 = 2413.f / 4096.f * 32.f;
- static constexpr float c3 = 2392.f / 4096.f * 32.f;
- static constexpr float a = 3711.f / 4096.f / 8.f;
-
- static constexpr float yScale = 100.f / 10000.f;
- static const float xScale;
-
- public:
- PQXTransferFunction(float exposure = 1.f)
- : HDRTransferFunction(exposure)
- {
- }
-
- __forceinline float forward(float y) const override
- {
- return pqxForward(y * exposure * yScale) * xScale;
- }
-
- __forceinline float inverse(float x) const override
- {
- return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure;
- }
-
- private:
- static __forceinline float pqForward(float y)
- {
- const float yp = pow(y, m1);
- return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2);
- }
-
- static __forceinline float pqxForward(float y)
- {
- if (y <= 1.f)
- return pqForward(y);
- else
- return a * log(y) + 1.f;
- }
-
- static __forceinline float pqInverse(float x)
- {
- const float xp = pow(x, 1.f/m2);
- return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1);
- }
-
- static __forceinline float pqxInverse(float x)
- {
- if (x <= 1.f)
- return pqInverse(x);
- else
- return exp((x - 1.f) * (1.f/a));
- }
- };
-
- // Autoexposure node
- class AutoexposureNode : public Node
- {
- private:
- Image color;
- std::shared_ptr<HDRTransferFunction> transferFunc;
-
- public:
- AutoexposureNode(const Image& color,
- const std::shared_ptr<HDRTransferFunction>& transferFunc)
- : color(color),
- transferFunc(transferFunc)
- {}
-
- void execute(stream& sm) override
- {
- const float exposure = autoexposure(color);
- //printf("exposure = %f\n", exposure);
- transferFunc->setExposure(exposure);
- }
-
- private:
- static float autoexposure(const Image& color);
- };
-
-} // namespace oidn
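
As a quick sanity check of the logarithmic HDR curve above (with exposure fixed to 1 for brevity): xScale is chosen so that yMax maps to exactly 1, and inverse() undoes forward() up to floating-point error.

    #include <cassert>
    #include <cmath>

    constexpr float yMax = 65504.f;                  // HDR range compressed to [0, 1]
    const float xScale = 1.f / std::log(yMax + 1.f);

    float logForward(float y) { return std::log(y + 1.f) * xScale; }
    float logInverse(float x) { return std::exp(x / xScale) - 1.f; }

    int main()
    {
      assert(std::fabs(logForward(0.f)) < 1e-6f);                       // black -> 0
      assert(std::fabs(logForward(yMax) - 1.f) < 1e-6f);                // yMax  -> 1
      assert(std::fabs(logInverse(logForward(100.f)) - 100.f) < 1e-2f); // round trip
      return 0;
    }
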
diff --git a/thirdparty/oidn/core/upsample.h b/thirdparty/oidn/core/upsample.h
deleted file mode 100644
index f6cace44cd..0000000000
--- a/thirdparty/oidn/core/upsample.h
+++ /dev/null
@@ -1,92 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "node.h"
-
-namespace oidn {
-
- // 2x2 nearest-neighbor upsampling node
- template<int K>
- class UpsampleNode : public Node
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- UpsampleNode(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : src(src),
- dst(dst)
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
- const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
- MAYBE_UNUSED(srcDesc);
- MAYBE_UNUSED(dstDesc);
- assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
- assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
- assert(srcDesc.ndims == 4);
- assert(dstDesc.ndims == 4);
- assert(srcDesc.data_type == memory::data_type::f32);
- assert(dstDesc.data_type == memory::data_type::f32);
- assert(srcDesc.dims[0] == 1);
- assert(dstDesc.dims[0] == 1);
- // 2x2 upsampling
- assert(dstDesc.dims[2] == srcDesc.dims[2] * 2);
- assert(dstDesc.dims[3] == srcDesc.dims[3] * 2);
- }
-
- void execute(stream& sm) override
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
-
- const float* srcPtr = (float*)src->get_data_handle();
- float* dstPtr = (float*)dst->get_data_handle();
-
- const int C = srcDesc.dims[1];
- const int H = srcDesc.dims[2];
- const int W = srcDesc.dims[3];
- const int CK = C / K;
-
- parallel_nd(CK, H, [&](int ck, int h)
- {
- const size_t offset = ck*H*W*K + h*W*K;
- const float* srcPtr_line = srcPtr + offset;
- float* dstPtr_line0 = dstPtr + offset * 4;
- float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line
-
- for (int w = 0; w < W; ++w)
- {
- #pragma unroll
- for (int k = 0; k < K; k += 4)
- {
- const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]);
-
- _mm_stream_ps(&dstPtr_line0[w*2*K + k], m);
- _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m);
- _mm_stream_ps(&dstPtr_line1[w*2*K + k], m);
- _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m);
- }
- }
- });
- }
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
-} // namespace oidn
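
A scalar reference for the SSE kernel above, useful for following the addressing: each source pixel of the blocked nChwKc tensor is replicated into a 2x2 block of the destination, channel block by channel block.

    #include <cstddef>

    // src has extents {1, C, H, W} in nChwKc layout (C a multiple of K);
    // dst has extents {1, C, 2*H, 2*W} in the same layout.
    void upsample2x2(const float* src, float* dst, int C, int H, int W, int K)
    {
      for (int ck = 0; ck < C / K; ++ck)
        for (int h = 0; h < H; ++h)
          for (int w = 0; w < W; ++w)
            for (int k = 0; k < K; ++k)
            {
              const float x = src[((std::size_t(ck) * H + h) * W + w) * K + k];
              for (int dh = 0; dh < 2; ++dh)
                for (int dw = 0; dw < 2; ++dw)
                {
                  const std::size_t dstIdx =
                      ((std::size_t(ck) * (2*H) + (2*h + dh)) * (2*W) + (2*w + dw)) * K + k;
                  dst[dstIdx] = x;
                }
            }
    }
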
diff --git a/thirdparty/oidn/core/weights_reorder.h b/thirdparty/oidn/core/weights_reorder.h
deleted file mode 100644
index 6c5dacb8aa..0000000000
--- a/thirdparty/oidn/core/weights_reorder.h
+++ /dev/null
@@ -1,99 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "node.h"
-
-namespace oidn {
-
- // Reorders weights from oihw to padded oihw format
- template<int K>
- class WeightsReorderNode : public Node
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- WeightsReorderNode(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : src(src),
- dst(dst)
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
- const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
- MAYBE_UNUSED(srcDesc);
- MAYBE_UNUSED(dstDesc);
- assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
- assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
- assert(srcDesc.ndims == 4);
- assert(dstDesc.ndims == 4);
- assert(srcDesc.data_type == memory::data_type::f32);
- assert(dstDesc.data_type == memory::data_type::f32);
- assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC
- assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC
- assert(srcDesc.dims[2] == dstDesc.dims[2]);
- assert(srcDesc.dims[3] == dstDesc.dims[3]);
- }
-
- void execute(stream& sm) override
- {
- const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
- const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
-
- const float* srcPtr = (float*)src->get_data_handle();
- float* dstPtr = (float*)dst->get_data_handle();
-
- const int OC1 = srcDesc.dims[0];
- const int OC2 = dstDesc.dims[0];
- const int IC1 = srcDesc.dims[1];
- const int IC2 = dstDesc.dims[1];
- const int H = dstDesc.dims[2];
- const int W = dstDesc.dims[3];
-
- for (int oc = 0; oc < OC2; ++oc)
- {
- for (int ic = 0; ic < IC2; ++ic)
- {
- for (int h = 0; h < H; ++h)
- {
- for (int w = 0; w < W; ++w)
- {
- // Output is in oihw format
- float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w;
-
- if (oc < OC1 && ic < IC1)
- {
- // Input is in oihw format
- const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w;
- *dstPtr_c = *srcPtr_c;
- }
- else
- {
- // padding
- *dstPtr_c = 0;
- }
- }
- }
- }
- }
- }
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
-} // namespace oidn
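
The asserts above imply that getPadded<K> (not shown in this hunk) rounds a channel count up to the next multiple of the block size K; the extra output/input channels are then zero-filled so they contribute nothing to the convolution. Assuming that reading, the rounding itself is just:

    // Round a channel count up to the next multiple of the block size K.
    constexpr int roundUpToMultiple(int n, int K)
    {
      return (n + K - 1) / K * K;
    }

    static_assert(roundUpToMultiple(3, 16) == 16, "3 channels pad up to one block");
    static_assert(roundUpToMultiple(32, 16) == 32, "aligned counts are unchanged");
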
diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.h b/thirdparty/oidn/include/OpenImageDenoise/oidn.h
deleted file mode 100644
index 57ba6baa21..0000000000
--- a/thirdparty/oidn/include/OpenImageDenoise/oidn.h
+++ /dev/null
@@ -1,214 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include <stddef.h>
-#include <stdbool.h>
-#include <stdint.h>
-
-#include "version.h"
-
-#if defined(__cplusplus)
-extern "C" {
-#endif
-
-#ifndef OIDN_API
-#if defined(_WIN32) && !defined(OIDN_STATIC_LIB)
-# define OIDN_API __declspec(dllimport)
-#else
-# define OIDN_API
-#endif
-#endif
-
-// ----------------------------------------------------------------------------
-// Device
-// ----------------------------------------------------------------------------
-
-// Device types
-typedef enum
-{
- OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically
-
- OIDN_DEVICE_TYPE_CPU = 1, // CPU device
-} OIDNDeviceType;
-
-// Error codes
-typedef enum
-{
- OIDN_ERROR_NONE = 0, // no error occurred
- OIDN_ERROR_UNKNOWN = 1, // an unknown error occurred
- OIDN_ERROR_INVALID_ARGUMENT = 2, // an invalid argument was specified
- OIDN_ERROR_INVALID_OPERATION = 3, // the operation is not allowed
- OIDN_ERROR_OUT_OF_MEMORY = 4, // not enough memory to execute the operation
- OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported
- OIDN_ERROR_CANCELLED = 6, // the operation was cancelled by the user
-} OIDNError;
-
-// Error callback function
-typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message);
-
-// Device handle
-typedef struct OIDNDeviceImpl* OIDNDevice;
-
-// Creates a new device.
-OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type);
-
-// Retains the device (increments the reference count).
-OIDN_API void oidnRetainDevice(OIDNDevice device);
-
-// Releases the device (decrements the reference count).
-OIDN_API void oidnReleaseDevice(OIDNDevice device);
-
-// Sets a boolean parameter of the device.
-OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value);
-
-// Sets an integer parameter of the device.
-OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value);
-
-// Gets a boolean parameter of the device.
-OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name);
-
-// Gets an integer parameter of the device (e.g. "version").
-OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name);
-
-// Sets the error callback function of the device.
-OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr);
-
-// Returns the first unqueried error code stored in the device for the current
-// thread, optionally also returning a string message (if not NULL), and clears
-// the stored error. Can be called with a NULL device as well to check why a
-// device creation failed.
-OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage);
-
-// Commits all previous changes to the device.
-// Must be called before first using the device (e.g. creating filters).
-OIDN_API void oidnCommitDevice(OIDNDevice device);
-
-// ----------------------------------------------------------------------------
-// Buffer
-// ----------------------------------------------------------------------------
-
-// Formats for images and other data stored in buffers
-typedef enum
-{
- OIDN_FORMAT_UNDEFINED = 0,
-
- // 32-bit single-precision floating point scalar and vector formats
- OIDN_FORMAT_FLOAT = 1,
- OIDN_FORMAT_FLOAT2 = 2,
- OIDN_FORMAT_FLOAT3 = 3,
- OIDN_FORMAT_FLOAT4 = 4,
-} OIDNFormat;
-
-// Access modes for mapping buffers
-typedef enum
-{
- OIDN_ACCESS_READ = 0, // read-only access
- OIDN_ACCESS_WRITE = 1, // write-only access
- OIDN_ACCESS_READ_WRITE = 2, // read and write access
- OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded
-} OIDNAccess;
-
-// Buffer handle
-typedef struct OIDNBufferImpl* OIDNBuffer;
-
-// Creates a new buffer (data allocated and owned by the device).
-OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize);
-
-// Creates a new shared buffer (data allocated and owned by the user).
-OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize);
-
-// Maps a region of the buffer to host memory.
-// If byteSize is 0, the maximum available amount of memory will be mapped.
-OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize);
-
-// Unmaps a region of the buffer.
-// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer.
-OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr);
-
-// Retains the buffer (increments the reference count).
-OIDN_API void oidnRetainBuffer(OIDNBuffer buffer);
-
-// Releases the buffer (decrements the reference count).
-OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer);
-
-// ----------------------------------------------------------------------------
-// Filter
-// ----------------------------------------------------------------------------
-
-// Progress monitor callback function
-typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n);
-
-// Filter handle
-typedef struct OIDNFilterImpl* OIDNFilter;
-
-// Creates a new filter of the specified type (e.g. "RT").
-OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type);
-
-// Retains the filter (increments the reference count).
-OIDN_API void oidnRetainFilter(OIDNFilter filter);
-
-// Releases the filter (decrements the reference count).
-OIDN_API void oidnReleaseFilter(OIDNFilter filter);
-
-// Sets an image parameter of the filter (stored in a buffer).
-// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
-OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name,
- OIDNBuffer buffer, OIDNFormat format,
- size_t width, size_t height,
- size_t byteOffset,
- size_t bytePixelStride, size_t byteRowStride);
-
-// Sets an image parameter of the filter (owned by the user).
-// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
-OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name,
- void* ptr, OIDNFormat format,
- size_t width, size_t height,
- size_t byteOffset,
- size_t bytePixelStride, size_t byteRowStride);
-
-// Sets a boolean parameter of the filter.
-OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value);
-
-// Gets a boolean parameter of the filter.
-OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name);
-
-// Sets an integer parameter of the filter.
-OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value);
-
-// Gets an integer parameter of the filter.
-OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name);
-
-// Sets a float parameter of the filter.
-OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value);
-
-// Gets a float parameter of the filter.
-OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name);
-
-// Sets the progress monitor callback function of the filter.
-OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr);
-
-// Commits all previous changes to the filter.
-// Must be called before first executing the filter.
-OIDN_API void oidnCommitFilter(OIDNFilter filter);
-
-// Executes the filter.
-OIDN_API void oidnExecuteFilter(OIDNFilter filter);
-
-#if defined(__cplusplus)
-}
-#endif
diff --git a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp b/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
deleted file mode 100644
index 9f95a56fe1..0000000000
--- a/thirdparty/oidn/include/OpenImageDenoise/oidn.hpp
+++ /dev/null
@@ -1,468 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include <algorithm>
-#include "oidn.h"
-
-namespace oidn {
-
- // --------------------------------------------------------------------------
- // Buffer
- // --------------------------------------------------------------------------
-
- // Formats for images and other data stored in buffers
- enum class Format
- {
- Undefined = OIDN_FORMAT_UNDEFINED,
-
- // 32-bit single-precision floating point scalar and vector formats
- Float = OIDN_FORMAT_FLOAT,
- Float2 = OIDN_FORMAT_FLOAT2,
- Float3 = OIDN_FORMAT_FLOAT3,
- Float4 = OIDN_FORMAT_FLOAT4,
- };
-
- // Access modes for mapping buffers
- enum class Access
- {
- Read = OIDN_ACCESS_READ, // read-only access
- Write = OIDN_ACCESS_WRITE, // write-only access
- ReadWrite = OIDN_ACCESS_READ_WRITE, // read and write access
- WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded
- };
-
- // Buffer object with automatic reference counting
- class BufferRef
- {
- private:
- OIDNBuffer handle;
-
- public:
- BufferRef() : handle(nullptr) {}
- BufferRef(OIDNBuffer handle) : handle(handle) {}
-
- BufferRef(const BufferRef& other) : handle(other.handle)
- {
- if (handle)
- oidnRetainBuffer(handle);
- }
-
- BufferRef(BufferRef&& other) : handle(other.handle)
- {
- other.handle = nullptr;
- }
-
- BufferRef& operator =(const BufferRef& other)
- {
- if (&other != this)
- {
- if (other.handle)
- oidnRetainBuffer(other.handle);
- if (handle)
- oidnReleaseBuffer(handle);
- handle = other.handle;
- }
- return *this;
- }
-
- BufferRef& operator =(BufferRef&& other)
- {
- std::swap(handle, other.handle);
- return *this;
- }
-
- BufferRef& operator =(OIDNBuffer other)
- {
- if (other)
- oidnRetainBuffer(other);
- if (handle)
- oidnReleaseBuffer(handle);
- handle = other;
- return *this;
- }
-
- ~BufferRef()
- {
- if (handle)
- oidnReleaseBuffer(handle);
- }
-
- OIDNBuffer getHandle() const
- {
- return handle;
- }
-
- operator bool() const
- {
- return handle != nullptr;
- }
-
- // Maps a region of the buffer to host memory.
- // If byteSize is 0, the maximum available amount of memory will be mapped.
- void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0)
- {
- return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize);
- }
-
- // Unmaps a region of the buffer.
- // mappedPtr must be a pointer returned by a previous call to map.
- void unmap(void* mappedPtr)
- {
- oidnUnmapBuffer(handle, mappedPtr);
- }
- };
-
- // --------------------------------------------------------------------------
- // Filter
- // --------------------------------------------------------------------------
-
- // Progress monitor callback function
- typedef bool (*ProgressMonitorFunction)(void* userPtr, double n);
-
- // Filter object with automatic reference counting
- class FilterRef
- {
- private:
- OIDNFilter handle;
-
- public:
- FilterRef() : handle(nullptr) {}
- FilterRef(OIDNFilter handle) : handle(handle) {}
-
- FilterRef(const FilterRef& other) : handle(other.handle)
- {
- if (handle)
- oidnRetainFilter(handle);
- }
-
- FilterRef(FilterRef&& other) : handle(other.handle)
- {
- other.handle = nullptr;
- }
-
- FilterRef& operator =(const FilterRef& other)
- {
- if (&other != this)
- {
- if (other.handle)
- oidnRetainFilter(other.handle);
- if (handle)
- oidnReleaseFilter(handle);
- handle = other.handle;
- }
- return *this;
- }
-
- FilterRef& operator =(FilterRef&& other)
- {
- std::swap(handle, other.handle);
- return *this;
- }
-
- FilterRef& operator =(OIDNFilter other)
- {
- if (other)
- oidnRetainFilter(other);
- if (handle)
- oidnReleaseFilter(handle);
- handle = other;
- return *this;
- }
-
- ~FilterRef()
- {
- if (handle)
- oidnReleaseFilter(handle);
- }
-
- OIDNFilter getHandle() const
- {
- return handle;
- }
-
- operator bool() const
- {
- return handle != nullptr;
- }
-
- // Sets an image parameter of the filter (stored in a buffer).
- void setImage(const char* name,
- const BufferRef& buffer, Format format,
- size_t width, size_t height,
- size_t byteOffset = 0,
- size_t bytePixelStride = 0, size_t byteRowStride = 0)
- {
- oidnSetFilterImage(handle, name,
- buffer.getHandle(), (OIDNFormat)format,
- width, height,
- byteOffset,
- bytePixelStride, byteRowStride);
- }
-
- // Sets an image parameter of the filter (owned by the user).
- void setImage(const char* name,
- void* ptr, Format format,
- size_t width, size_t height,
- size_t byteOffset = 0,
- size_t bytePixelStride = 0, size_t byteRowStride = 0)
- {
- oidnSetSharedFilterImage(handle, name,
- ptr, (OIDNFormat)format,
- width, height,
- byteOffset,
- bytePixelStride, byteRowStride);
- }
-
- // Sets a boolean parameter of the filter.
- void set(const char* name, bool value)
- {
- oidnSetFilter1b(handle, name, value);
- }
-
- // Sets an integer parameter of the filter.
- void set(const char* name, int value)
- {
- oidnSetFilter1i(handle, name, value);
- }
-
- // Sets a float parameter of the filter.
- void set(const char* name, float value)
- {
- oidnSetFilter1f(handle, name, value);
- }
-
- // Gets a parameter of the filter.
- template<typename T>
- T get(const char* name);
-
- // Sets the progress monitor callback function of the filter.
- void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr)
- {
- oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr);
- }
-
- // Commits all previous changes to the filter.
- void commit()
- {
- oidnCommitFilter(handle);
- }
-
- // Executes the filter.
- void execute()
- {
- oidnExecuteFilter(handle);
- }
- };
-
- // Gets a boolean parameter of the filter.
- template<>
- inline bool FilterRef::get(const char* name)
- {
- return oidnGetFilter1b(handle, name);
- }
-
- // Gets an integer parameter of the filter.
- template<>
- inline int FilterRef::get(const char* name)
- {
- return oidnGetFilter1i(handle, name);
- }
-
- // Gets a float parameter of the filter.
- template<>
- inline float FilterRef::get(const char* name)
- {
- return oidnGetFilter1f(handle, name);
- }
-
- // --------------------------------------------------------------------------
- // Device
- // --------------------------------------------------------------------------
-
- // Device types
- enum class DeviceType
- {
- Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically
-
- CPU = OIDN_DEVICE_TYPE_CPU, // CPU device
- };
-
- // Error codes
- enum class Error
- {
- None = OIDN_ERROR_NONE, // no error occurred
- Unknown = OIDN_ERROR_UNKNOWN, // an unknown error occurred
- InvalidArgument = OIDN_ERROR_INVALID_ARGUMENT, // an invalid argument was specified
- InvalidOperation = OIDN_ERROR_INVALID_OPERATION, // the operation is not allowed
- OutOfMemory = OIDN_ERROR_OUT_OF_MEMORY, // not enough memory to execute the operation
- UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported
- Cancelled = OIDN_ERROR_CANCELLED, // the operation was cancelled by the user
- };
-
- // Error callback function
- typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message);
-
- // Device object with automatic reference counting
- class DeviceRef
- {
- private:
- OIDNDevice handle;
-
- public:
- DeviceRef() : handle(nullptr) {}
- DeviceRef(OIDNDevice handle) : handle(handle) {}
-
- DeviceRef(const DeviceRef& other) : handle(other.handle)
- {
- if (handle)
- oidnRetainDevice(handle);
- }
-
- DeviceRef(DeviceRef&& other) : handle(other.handle)
- {
- other.handle = nullptr;
- }
-
- DeviceRef& operator =(const DeviceRef& other)
- {
- if (&other != this)
- {
- if (other.handle)
- oidnRetainDevice(other.handle);
- if (handle)
- oidnReleaseDevice(handle);
- handle = other.handle;
- }
- return *this;
- }
-
- DeviceRef& operator =(DeviceRef&& other)
- {
- std::swap(handle, other.handle);
- return *this;
- }
-
- DeviceRef& operator =(OIDNDevice other)
- {
- if (other)
- oidnRetainDevice(other);
- if (handle)
- oidnReleaseDevice(handle);
- handle = other;
- return *this;
- }
-
- ~DeviceRef()
- {
- if (handle)
- oidnReleaseDevice(handle);
- }
-
- OIDNDevice getHandle() const
- {
- return handle;
- }
-
- operator bool() const
- {
- return handle != nullptr;
- }
-
- // Sets a boolean parameter of the device.
- void set(const char* name, bool value)
- {
- oidnSetDevice1b(handle, name, value);
- }
-
- // Sets an integer parameter of the device.
- void set(const char* name, int value)
- {
- oidnSetDevice1i(handle, name, value);
- }
-
- // Gets a parameter of the device.
- template<typename T>
- T get(const char* name);
-
- // Sets the error callback function of the device.
- void setErrorFunction(ErrorFunction func, void* userPtr = nullptr)
- {
- oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr);
- }
-
- // Returns the first unqueried error code and clears the stored error.
- // Can be called for a null device as well to check why a device creation failed.
- Error getError()
- {
- return (Error)oidnGetDeviceError(handle, nullptr);
- }
-
- // Returns the first unqueried error code and string message, and clears the stored error.
- // Can be called for a null device as well to check why a device creation failed.
- Error getError(const char*& outMessage)
- {
- return (Error)oidnGetDeviceError(handle, &outMessage);
- }
-
- // Commits all previous changes to the device.
- // Must be called before first using the device (e.g. creating filters).
- void commit()
- {
- oidnCommitDevice(handle);
- }
-
- // Creates a new buffer (data allocated and owned by the device).
- BufferRef newBuffer(size_t byteSize)
- {
- return oidnNewBuffer(handle, byteSize);
- }
-
- // Creates a new shared buffer (data allocated and owned by the user).
- BufferRef newBuffer(void* ptr, size_t byteSize)
- {
- return oidnNewSharedBuffer(handle, ptr, byteSize);
- }
-
- // Creates a new filter of the specified type (e.g. "RT").
- FilterRef newFilter(const char* type)
- {
- return oidnNewFilter(handle, type);
- }
- };
-
- // Gets a boolean parameter of the device.
- template<>
- inline bool DeviceRef::get(const char* name)
- {
- return oidnGetDevice1b(handle, name);
- }
-
- // Gets an integer parameter of the device (e.g. "version").
- template<>
- inline int DeviceRef::get(const char* name)
- {
- return oidnGetDevice1i(handle, name);
- }
-
- // Creates a new device.
- inline DeviceRef newDevice(DeviceType type = DeviceType::Default)
- {
- return DeviceRef(oidnNewDevice((OIDNDeviceType)type));
- }
-
-} // namespace oidn
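
The same workflow through the C++ wrappers above is shorter, since the Ref classes retain and release the underlying handles automatically; the parameter names are the usual OIDN ones and are assumed here, as in the C example.

    #include <vector>
    #include <OpenImageDenoise/oidn.hpp>

    int main()
    {
      const size_t width = 1280, height = 720;
      std::vector<float> color(width * height * 3), output(width * height * 3);

      oidn::DeviceRef device = oidn::newDevice();
      device.commit();

      oidn::FilterRef filter = device.newFilter("RT");
      filter.setImage("color",  color.data(),  oidn::Format::Float3, width, height);
      filter.setImage("output", output.data(), oidn::Format::Float3, width, height);
      filter.set("hdr", true);
      filter.commit();
      filter.execute();

      const char* message;
      if (device.getError(message) != oidn::Error::None)
        return 1;
      return 0; // handles are released as the refs go out of scope
    }
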
diff --git a/thirdparty/oidn/include/OpenImageDenoise/version.h b/thirdparty/oidn/include/OpenImageDenoise/version.h
deleted file mode 100644
index 66b347c992..0000000000
--- a/thirdparty/oidn/include/OpenImageDenoise/version.h
+++ /dev/null
@@ -1,23 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#define OIDN_VERSION_MAJOR 1
-#define OIDN_VERSION_MINOR 1
-#define OIDN_VERSION_PATCH 0
-#define OIDN_VERSION 10100
-#define OIDN_VERSION_STRING "1.1.0"
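
The packed OIDN_VERSION value above encodes major * 10000 + minor * 100 + patch (so 1.1.0 becomes 10100), which makes compile-time version checks straightforward:

    #include <OpenImageDenoise/version.h>

    static_assert(OIDN_VERSION ==
                  OIDN_VERSION_MAJOR * 10000 + OIDN_VERSION_MINOR * 100 + OIDN_VERSION_PATCH,
                  "packed OIDN version must match its components");
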
diff --git a/thirdparty/oidn/mkl-dnn/LICENSE b/thirdparty/oidn/mkl-dnn/LICENSE
deleted file mode 100644
index d13f7b7ca0..0000000000
--- a/thirdparty/oidn/mkl-dnn/LICENSE
+++ /dev/null
@@ -1,214 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "{}"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright {yyyy} {name of copyright owner}
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- ============================================================================
-
- Intel MKL-DNN includes components with separate copyright
- notices and license terms.
-
- XByak, 3-clause BSD license
- Copyright (c) 2007 MITSUNARI Shigeo
- See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT
-
- gtest, 3-clause BSD license
- Copyright 2008, Google Inc.
- See full copyright notice and license text in tests/gtests/gtest/LICENSE
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.h b/thirdparty/oidn/mkl-dnn/include/mkldnn.h
deleted file mode 100644
index 9b64994922..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn.h
+++ /dev/null
@@ -1,1771 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_H
-#define MKLDNN_H
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-
-/* All symbols shall be internal unless marked as MKLDNN_API */
-#if defined _WIN32 || defined __CYGWIN__
-# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
-# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
-#else
-# if __GNUC__ >= 4
-# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
-# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
-# else
-# define MKLDNN_HELPER_DLL_IMPORT
-# define MKLDNN_HELPER_DLL_EXPORT
-# endif
-#endif
-
-#ifdef MKLDNN_DLL
-# ifdef MKLDNN_DLL_EXPORTS
-# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
-# else
-# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
-# endif
-#else
-# define MKLDNN_API
-#endif
-
-#if defined (__GNUC__)
-# define MKLDNN_DEPRECATED __attribute__((deprecated))
-#elif defined(_MSC_VER)
-# define MKLDNN_DEPRECATED __declspec(deprecated)
-#else
-# define MKLDNN_DEPRECATED
-#endif
-
-#include "mkldnn_types.h"
-#include "mkldnn_version.h"
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/** @addtogroup c_api C API
- * @{ */
-
-/** @addtogroup c_api_primitive Primitive operations
- * @{ */
-
-/** @addtogroup c_api_primitive_common Common primitive operations
- * @{ */
-
-/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr,
- * @p engine, and optionally a hint primitive descriptor from forward
- * propagation (required for backward propagation). Pass @c NULL for forward
- * propagation.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create(
- mkldnn_primitive_desc_iterator_t *iterator,
- const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
- mkldnn_engine_t engine,
- const_mkldnn_primitive_desc_t hint_forward_primitive_desc);
-
-/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no
- * more primitive descriptors are available. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(
- mkldnn_primitive_desc_iterator_t iterator);
-
-/** Fetches the current primitive descriptor.
- *
- * @note
- * The user should delete the fetched primitive descriptor using
- * mkldnn_primitive_desc_destroy() once it is no longer needed. */
-mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(
- const_mkldnn_primitive_desc_iterator_t iterator);
-
-/** Deletes a primitive descriptor @p iterator */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(
- mkldnn_primitive_desc_iterator_t iterator);
-
-/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and
- * optionally a hint primitive descriptor from forward propagation. The call is
- * equivalent to creating a primitive descriptor iterator, immediately fetching
- * a primitive descriptor, and then destroying the iterator. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create(
- mkldnn_primitive_desc_t *primitive_desc,
- const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
- mkldnn_engine_t engine,
- const_mkldnn_primitive_desc_t hint_forward_primitive_desc);
-
-/** Makes a copy of a @p primitive_desc. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(
- mkldnn_primitive_desc_t *primitive_desc,
- const_mkldnn_primitive_desc_t existing_primitive_desc);
-
-/** Returns a constant reference to the attribute of a @p primitive_desc.
- *
- * @warning
- * The user should not destroy the obtained @p attr.
- *
- * @warning
- * The lifetime of an @p attr is the same as that of a @p primitive_desc,
- * so it is illegal to use the @p attr once @p primitive_desc has been
- * destroyed. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(
- const_mkldnn_primitive_desc_t primitive_desc,
- const_mkldnn_primitive_attr_t *attr);
-
-/** Deletes a @p primitive_desc. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(
- mkldnn_primitive_desc_t primitive_desc);
-
-/** Queries primitive descriptor
- *
- * One of the most typical use cases is to query a convolution primitive
- * descriptor created with source, weights, and destination formats equal
- * to #mkldnn_format_tag_any about the corresponding memory descriptors
- * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and
- * #mkldnn_query_dst_md respectively) to be able to prepare memory and
- * create reorders if required.
- *
- * Another quite typical use case is to query an operation primitive
- * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md).
- * The returned status #mkldnn_not_required indicates that a workspace is
- * not required.
- *
- * A few other possibilities:
- * - query an operation primitive descriptor for the underlying operation
- * descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d,
- * #mkldnn_query_rnn_d, etc.)
- * - query an operation primitive descriptor for the implementation
- * information string (#mkldnn_query_impl_info_str)
- * - query an operation primitive descriptor for the number of inputs and
- * outputs (#mkldnn_query_num_of_inputs_s32 and
- * #mkldnn_query_num_of_outputs_s32 respectively)
- *
- * @sa mkldnn_query_t for more options
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(
- const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
- int index, void *result);
-
-/** Queries primitive descriptor for memory descriptor
- *
- * @returns NULL in case of any error.
- *
- * This is just a specialized version of mkldnn_primitive_desc_query
- * used for convenience.
- */
-const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md(
- const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
- int index);
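The query API above is most often used to discover which memory layout a primitive selected for its inputs and outputs. A minimal sketch, assuming conv_pd is an already created convolution primitive descriptor and user_src_md describes the caller's source buffer (both names are illustrative):

#include "mkldnn.h"

/* Returns 1 if the user's source data must be reordered into the layout the
 * primitive selected, 0 if the layouts already match, -1 on query failure. */
static int conv_needs_src_reorder(const_mkldnn_primitive_desc_t conv_pd,
        const mkldnn_memory_desc_t *user_src_md) {
    /* Layout chosen by the primitive, e.g. when it was created with
     * #mkldnn_format_tag_any for its source. */
    const mkldnn_memory_desc_t *conv_src_md
            = mkldnn_primitive_desc_query_md(conv_pd, mkldnn_query_src_md, 0);
    if (conv_src_md == NULL)
        return -1;

    /* mkldnn_memory_desc_equal() returns 1 when the descriptors match. */
    return mkldnn_memory_desc_equal(user_src_md, conv_src_md) ? 0 : 1;
}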
-
-/** Queries primitive descriptor for signed 32bit int
- *
- * @returns 0 in case of any error (in particular if the queried entity is
- * not of type int32_t). Note that 0 might also be the actual returned
- * value.
- *
- * This is just a specialized version of mkldnn_primitive_desc_query
- * used for convenience.
- */
-int MKLDNN_API mkldnn_primitive_desc_query_s32(
- const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
- int index);
-
-/** Creates a @p primitive using a @p primitive_desc descriptor. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_create(
- mkldnn_primitive_t *primitive,
- const_mkldnn_primitive_desc_t primitive_desc);
-
-/** Executes a @p primitive using a @p stream, and @p nargs arguments
- * @p args. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_execute(
- const_mkldnn_primitive_t primitive, mkldnn_stream_t stream,
- int nargs, const mkldnn_exec_arg_t *args);
-
-/** Retrieves a reference to the @p primitive_desc descriptor of given @p
- * primitive.
- *
- * @warning
- * The returned object must not be destroyed by the user. The @c const
- * qualifier of the returned object prevents such attempts. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(
- const_mkldnn_primitive_t primitive,
- const_mkldnn_primitive_desc_t *primitive_desc);
-
-/** Deletes a @p primitive. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(
- mkldnn_primitive_t primitive);
-
-/** @} */
-
-/** @addtogroup c_api_attributes Attributes
- * An extension for controlling primitive behavior.
- * @{ */
-
-/** Creates an empty (default) @p attr attribute. All the parameters are set to
- * default values.
- *
- * An empty attribute is used in primitive descriptor creation whenever it
- * is not passed explicitly, e.g. in mkldnn_primitive_desc_create.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(
- mkldnn_primitive_attr_t *attr);
-
-/** Makes a copy of an @p existing_attr. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(
- mkldnn_primitive_attr_t *attr,
- const_mkldnn_primitive_attr_t existing_attr);
-
-/** Deletes an @p attr. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(
- mkldnn_primitive_attr_t attr);
-
-/** Returns the scratchpad @p mode set in the attribute @p attr */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode(
- const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode);
-
-/** Sets scratchpad @p mode.
- *
- * The possible values are: #mkldnn_scratchpad_mode_library (default) and
- * #mkldnn_scratchpad_mode_user. */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode(
- mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode);
-
-/** Returns @p count, correspondence scale @p mask, and a pointer to a constant
- * floating point array of output @p scales for given @p attr, previously set
- * by mkldnn_primitive_attr_set_output_scales.
- *
- * @warning
- * The @p scales array points to the internal @p attr field, so the user
- * should not modify or destroy @p scales.
- *
- * @warning
- * The lifetime of @p scales is the same as that of the @p attr to which it
- * belongs, so it is illegal to use @p scales after @p attr is destroyed.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(
- const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask,
- const float **scales);
-
-/** Sets output @p scales for primitive operations. The number of elements @p
- * count and correspondence scale @p mask are stored for future use.
- *
- * The @p mask argument defines the correspondence between the output tensor
- * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a
- * dedicated scaling factor for each slice of the output tensor over the i-th
- * dimension. Set @p mask to 0 to use a common scaling factor for the whole
- * output tensor.
- *
- * @note
- * The dimension order is always native and does not depend on the actual
- * layout used. Examples:
- *      - for 2D data the order of dimensions is always: (n, c)
- *      - for 4D data the order is always: (n, c, h, w)
- *      - for 5D weights the order is always: (g, oc, ic, kh, kw)
- *
- * Example usage:
- * @code
- * int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
- * float scales[oc] = { ... }; // unique output scales per output channel
- * int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
- *
- * mkldnn_convolution_desc_t cd; // create & configure convolution op_desc
- *
- * mkldnn_primitive_attr_t attr;
- * mkldnn_primitive_attr_create(&attr); // create default attributes
- * mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
- *
- * mkldnn_primitive_desc_t cpd;
- * mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL);
- * @endcode
- *
- * @note
- * There is no way to check that @p count corresponds to @p mask until an
- * actual primitive descriptor is created, so it is the user's
- * responsibility to set proper values. The following formula must hold:
- *
- * \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(
- mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask,
- const float *scales);
-
-/** Returns @p post_ops for given @p attr.
- *
- * @warning
- * @p post_ops points to the internal @p attr field, so the user should not
- * modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the
- * same as that of the @p attr it belongs to, so it is illegal to use @p
- * post_ops after @p attr has been destroyed.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(
- const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops);
-
-/** Sets configured @p post_ops to an attribute @p attr for future use (when
- * primitive descriptor is being created).
- *
- * @note
- * At this point in time, there is no way to check whether the primitive
- * descriptor does or does not support a given sequence of post operations.
- * Therefore the user should handle an error that might occur at the
- * mkldnn_primitive_desc_create call.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(
- mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops);
-
-/** @addtogroup c_api_attributes_post_ops Sequence of post operations
- * An extension for performing extra operations after a base operation.
- * @{ */
-
-/** Creates an empty sequence of post operations @p post_ops. */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops);
-
-/** Deletes a @p post_ops sequence. */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops);
-
-/** Returns the @p length of post operations for given @p post_ops. */
-int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops);
-
-/** Returns the type of post operation with index @p index in given
- * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */
-mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(
- const_mkldnn_post_ops_t post_ops, int index);
-
-/** Appends accumulation (sum) post operation to the @p post_ops. Prior to
- * accumulating the result, the previous value would be multiplied by @p scale.
- *
- * The kind of this post operation is #mkldnn_sum.
- *
- * This feature might improve performance for cases like residual learning
- * blocks, where the result of convolution is accumulated to the previously
- * computed activations. The parameter @p scale might be extreme for the
- * integer-based computations when the result and previous activations have
- * different logical scaling factors.
- *
- * In the simplest case when the accumulation is the only post operation, the
- * computations would be:
- * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...)
- *
- * @note
- * This post operation (as well as all the others) disregards the original
- * layout of the destination; that is, the layout of the original
- * destination is expected to be the same as the layout of the stored
- * destination.
- */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(
- mkldnn_post_ops_t post_ops, float scale);
-
-/** Gets the parameters of the accumulation (sum) post operation with index
- * @p index in the sequence of @p post_ops.
- *
- * @note
- * If index @p index would not correspond to the accumulation post
- * operation, the function returns #mkldnn_invalid_arguments.
- */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(
- const_mkldnn_post_ops_t post_ops, int index, float *scale);
-
-/** Appends eltwise post operation to the @p post_ops with given parameters
- * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and
- * mkldnn_eltwise_desc_t).
- *
- * The kind of this post operation is #mkldnn_eltwise.
- *
- * In the simplest case when the eltwise is the only post operation, the
- * computations would be:
- * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...)
- * where eltwise_op is configured with the given parameters.
- */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(
- mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg,
- float alpha, float beta);
-
-/** Gets the eltwise parameters of the post operation with index @p index in
- * the sequence of @p post_ops.
- */
-mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(
- const_mkldnn_post_ops_t post_ops, int index, float *scale,
- mkldnn_alg_kind_t *alg, float *alpha, float *beta);
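Putting the post-ops API together: the sketch below builds an attribute whose primitive first accumulates into the destination (sum) and then applies ReLU (eltwise). mkldnn_success and mkldnn_eltwise_relu are assumed to be the usual enumerators from mkldnn_types.h; the caller destroys *out_ops and *out_attr once the primitive descriptor has been created.

#include "mkldnn.h"

static mkldnn_status_t make_sum_relu_attr(mkldnn_primitive_attr_t *out_attr,
        mkldnn_post_ops_t *out_ops) {
    mkldnn_status_t s = mkldnn_post_ops_create(out_ops);
    if (s != mkldnn_success)
        return s;

    /* dst[] <- 1.0f * dst[] + op(...), followed by ReLU (alpha = beta = 0). */
    mkldnn_post_ops_append_sum(*out_ops, 1.0f);
    mkldnn_post_ops_append_eltwise(*out_ops, 1.0f, mkldnn_eltwise_relu,
            0.f, 0.f);

    s = mkldnn_primitive_attr_create(out_attr);
    if (s != mkldnn_success)
        return s;
    return mkldnn_primitive_attr_set_post_ops(*out_attr, *out_ops);
}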
-
-/** @} */
-
-/** @} */
-
-/** @addtogroup c_api_memory Memory
- * A primitive to describe and store data.
- *
- * The library supports various data types and formats. Memory hierarchy
- * consists of three levels of abstraction:
- * 1. **Memory descriptor** -- engine agnostic logical description of data
- * (number of dimensions, dimensions themselves, and data type), and
- * optionally the format/layout that describes the physical representation
- * of data in memory. If the format is not known yet, one can pass
- * #mkldnn_format_tag_any. This approach is used to allow compute-intensive
- * primitives to specify the most appropriate format on their own with
- * users required to reorder the data if the incoming format doesn't match
- * the primitive's selection. Memory descriptor can be initialized with
- * mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides()
- * functions, or by directly filling the mkldnn_memory_desc_t structure.
- * The latter requires deep knowledge of how the physical data
- * representation is mapped to the structure.
- * The @ref understanding_memory_formats topic should shed some light on
- * that.
- * For the fully defined memory descriptors (i.e. where the format kind is
- *      not equal to #mkldnn_format_kind_any) a user can get the size using the
- * mkldnn_memory_desc_get_size() function. As described in
- * @ref understanding_memory_formats, the size of data sometimes cannot
- * be computed as the product of dimensions times the size of the data
- * type. So users are encouraged to use this function for better code
- * portability.
- * Two memory descriptors can be compared with mkldnn_memory_desc_equal().
- * The comparison is especially useful when checking whether a primitive
- * requires reorder from the user's data format to the primitive's format.
- * 2. **Memory** -- an engine-specific object that handles the data and its
- *      description (a memory descriptor). For the CPU engine, the data handle is
- * simply a pointer to @c void. The data handle can be queried using
- * mkldnn_memory_get_data_handle() and set using
- * mkldnn_memory_set_data_handle(). The latter function always sets the
- * memory in the padding region to zero, which is the invariant maintained
- * by all the primitives in Intel MKL-DNN.
- * See @ref understanding_memory_formats for more details.
- * A memory can be created using mkldnn_memory_create() function.
- * A memory can also be queried for the underlying memory descriptor and
- * engine using mkldnn_memory_get_memory_desc() and
- * mkldnn_memory_get_engine() functions.
- *
- * Along with ordinary memory with all dimensions being positive, Intel
- * MKL-DNN supports *zero-volume* memory with one or more dimensions set to
- * zero. This is to support the NumPy\* convention.
- * If a *zero-volume* memory is passed to a primitive, the primitive does
- * not perform any computations on this memory. For example:
- * - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)`
- *        source and `(16 output channels, 3 input channels, 3 height, 3 width)`
- * weights would produce `(0 batch, 16 output channels, 11 height, 11 width)`
- * destination (assuming strides are `1` and paddings are zero) and perform
- * zero multiply-add operations.
- * - Concatenation of three memories of shapes `(3, 4, 13, 13)`,
- * `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce
- * the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second
- * input (however, if the user created a concatenation primitive descriptor
- * with three inputs they should also provide all three memories to the
- * concatenation primitive, including the one with zero second dimension).
- * - However, Intel MKL-DNN would return an error when attempting to create a
- * convolution with *zero-volume* memory passed for weights because such a
- * convolution is not well-defined:
- * ~~~
- * dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3)
- * ~~~
- * Should the values in the destination be zeroes or just not accessed at
- * all? Moreover, backward pass w.r.t. weights in such cases is also not
- * well-defined.
- *
- * Data handle of *zero-volume* memory is never accessed and hence can be
- * unset (NULL in case of CPU engine).
- *
- * @sa @ref understanding_memory_formats
- * @{ */
-
-/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p
- * data_type, and @p strides.
- *
- * The @p strides might be NULL, which means the order of physical dimensions
- * is the same as the order of logical ones.
- *
- * @note The logical order of dimensions is defined by a primitive that
- * consumes the memory.
- */
-mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides(
- mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims,
- mkldnn_data_type_t data_type, const mkldnn_dims_t strides);
-
-/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p
- * data_type, and format @p tag.
- *
- * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define
- * the appropriate memory format. In this case, the @p format_kind would be set
- * to #mkldnn_format_kind_any */
-mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag(
- mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims,
- mkldnn_data_type_t data_type, mkldnn_format_tag_t tag);
-
-/** Initializes a @p memory_desc for a given @p parent_memory_desc, with
- * @p dims sizes and @p offsets. May fail if the layout used does not allow
- * obtaining the desired submemory. In this case, consider using the `extract`
- * or `insert` primitive. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory(
- mkldnn_memory_desc_t *memory_desc,
- const mkldnn_memory_desc_t *parent_memory_desc,
- const mkldnn_dims_t dims, const mkldnn_dims_t offsets);
-
-/** Compares two memory descriptors.
- * @return 1 if the descriptors are the same.
- * @return 0 if the descriptors are different.
- *
- * Use this function to identify whether a reorder is required between the
- * two memories */
-int MKLDNN_API mkldnn_memory_desc_equal(
- const mkldnn_memory_desc_t *lhs,
- const mkldnn_memory_desc_t *rhs);
-
-/** Returns the size (in bytes) that is required for given @p memory_desc */
-size_t MKLDNN_API mkldnn_memory_desc_get_size(
- const mkldnn_memory_desc_t *memory_desc);
-
-/** Creates a memory for given @p memory_desc and @p engine. Also sets handle
- * to @p native_handle.
- * The @p native_handle can:
- * - point to the user allocated memory, i.e. valid handle. In this case the
- * library doesn't own allocated memory.
- * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and
- * attach memory. In this case the library owns allocated memory.
- * - be MKLDNN_NATIVE_HANDLE_NONE to create mkldnn_memory w/o attached memory.
- */
-mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory,
- const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine,
- void *native_handle);
-
-/** Returns a @p memory_desc associated with @p memory. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc(
- const_mkldnn_memory_t memory,
- const mkldnn_memory_desc_t **memory_desc);
-
-/** Returns an @p engine associated with @p memory. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine(
- const_mkldnn_memory_t memory, mkldnn_engine_t *engine);
-
-/** For a @p memory, returns the data @p handle.
- *
- * For the CPU engine, the data handle is a pointer to the actual data. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(
- const_mkldnn_memory_t memory, void **handle);
-
-/** For a @p memory, sets the data @p handle. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(
- mkldnn_memory_t memory, void *handle);
-
-/** Deletes a @p memory. */
-mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory);
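A minimal end-to-end sketch of the memory API above: describe a 4D f32 tensor in NCHW layout, query its size, and let the library allocate the backing storage. eng is assumed to be a previously created engine, and mkldnn_f32 / mkldnn_nchw are assumed to be the usual data-type and format-tag enumerators from mkldnn_types.h.

#include <stdio.h>
#include "mkldnn.h"

static mkldnn_status_t make_nchw_memory(mkldnn_engine_t eng,
        mkldnn_memory_t *out_mem) {
    mkldnn_dims_t dims = {32, 3, 227, 227};
    mkldnn_memory_desc_t md;

    mkldnn_status_t s = mkldnn_memory_desc_init_by_tag(
            &md, 4, dims, mkldnn_f32, mkldnn_nchw);
    if (s != mkldnn_success)
        return s;

    /* The size may exceed the plain product of dimensions and data type size
     * if the chosen format implies padding. */
    printf("buffer size: %zu bytes\n", mkldnn_memory_desc_get_size(&md));

    /* Ask the library to allocate and attach the buffer itself. */
    return mkldnn_memory_create(out_mem, &md, eng,
            MKLDNN_NATIVE_HANDLE_ALLOCATE);
}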
-
-/** @} */
-
-/** @addtogroup c_api_reorder Reorder
- * A primitive to copy data between memory formats.
- * @{ */
-
-/** Initializes a @p reorder_primitive_desc using the description of the source
- * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md)
- * memory, and an @p attr attribute.
- *
- * Inputs:
- * - input (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - output (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(
- mkldnn_primitive_desc_t *reorder_primitive_desc,
- mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md,
- mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md,
- const_mkldnn_primitive_attr_t attr);
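Continuing the memory example: when the user's descriptor and the one a primitive selected differ, a reorder bridges the two. A sketch, assuming both memories live on the same engine eng:

#include "mkldnn.h"

static mkldnn_status_t make_src_reorder_pd(mkldnn_engine_t eng,
        const mkldnn_memory_desc_t *user_src_md,
        const mkldnn_memory_desc_t *conv_src_md,
        mkldnn_primitive_desc_t *out_reorder_pd) {
    *out_reorder_pd = NULL;

    /* Layouts already match -- no reorder required. */
    if (mkldnn_memory_desc_equal(user_src_md, conv_src_md))
        return mkldnn_success;

    return mkldnn_reorder_primitive_desc_create(out_reorder_pd,
            eng, user_src_md, eng, conv_src_md, NULL /* default attributes */);
}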
-
-/** @} */
-
-/** @addtogroup c_api_concat Concat
- * A primitive to concatenate data by arbitrary dimension.
- * @{ */
-
-/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n
- * inputs by @p concat_dimension with resulting @p output_desc memory
- * descriptor. @p output_desc can be NULL or specified with the
- * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory
- * format would be chosen automatically.
- *
- * Inputs:
- * - input 0 (#mkldnn_query_src_md, 0)
- * - input 1 (#mkldnn_query_src_md, 1)
- * - ...
- * - input @p n - 1 (#mkldnn_query_src_md, @p n - 1)
- *
- * Outputs:
- * - output (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(
- mkldnn_primitive_desc_t *concat_primitive_desc,
- const mkldnn_memory_desc_t *dst_md,
- int n, int concat_dimension,
- const mkldnn_memory_desc_t *src_mds,
- const_mkldnn_primitive_attr_t attr,
- mkldnn_engine_t engine);
-
-/** @} */
-
-/** @addtogroup c_api_sum Sum
- * A primitive to sum data.
- * @{ */
-
-/** Creates out-of-place @p sum_primitive_desc for sum of @p n
- * inputs multiplied by scale with resulting @p output_desc memory
- * descriptor. @p output_desc can be NULL or specified with the
- * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory
- * format would be chosen automatically.
- *
- * Inputs:
- * - src 0 (#mkldnn_query_src_md, 0)
- * - src 1 (#mkldnn_query_src_md, 1)
- * - ...
- * - src @p n - 1 (#mkldnn_query_src_md, @p n - 1)
- *
- * Outputs:
- * - output (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(
- mkldnn_primitive_desc_t *sum_primitive_desc,
- const mkldnn_memory_desc_t *dst_mds,
- int n, const float *scales,
- const mkldnn_memory_desc_t *src_mds,
- const_mkldnn_primitive_attr_t attr,
- mkldnn_engine_t engine);
-
-/** @} */
-
-/** @addtogroup c_api_convolution Convolution
- * A primitive to compute convolution using different algorithms.
- *
- * \f[dst[n][oc][oh][ow] =
- * \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC}
- * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]
- * \cdot weights[g][oc][ic][kh][kw]
- * + bias[g][oc],\f]
- *
- * where size of output spatial domain is given by
- * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}}
- * \right\rfloor + 1\f$,
- * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}}
- * \right\rfloor + 1\f$,
- *
- * and summation is carried over input channels \f$ic\f$ in
- * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and
- * \f$p_l, p_r\f$ are @p padding_l and @p padding_r.
- * @{ */
-
-/** Initializes a convolution descriptor @p conv_desc for forward propagation
- * using @p prop_kind (possible values are #mkldnn_forward_training and
- * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p
- * padding_l, @p padding_r, and @p padding_kind. In order to create a
- * convolution without bias, @p bias_desc should either be @c NULL or point to
- * a descriptor with memory format kind equal to #mkldnn_format_kind_undef.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- * - bias (#mkldnn_query_weights_md, 1), if created with bias
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
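A sketch of the typical forward-inference setup: all memory descriptors are created with #mkldnn_format_tag_any so the convolution may pick its preferred layouts, and the primitive descriptor is created with default attributes. mkldnn_f32, mkldnn_convolution_direct, and mkldnn_padding_zero are assumed to be the usual enumerators from mkldnn_types.h; eng is an existing engine.

#include "mkldnn.h"

static mkldnn_status_t make_conv_pd(mkldnn_engine_t eng,
        mkldnn_primitive_desc_t *out_pd) {
    mkldnn_dims_t src_dims = {32, 3, 227, 227};
    mkldnn_dims_t wei_dims = {96, 3, 11, 11};
    mkldnn_dims_t dst_dims = {32, 96, 55, 55};
    mkldnn_dims_t strides = {4, 4}, padding = {0, 0};

    mkldnn_memory_desc_t src_md, wei_md, dst_md;
    mkldnn_memory_desc_init_by_tag(&src_md, 4, src_dims, mkldnn_f32,
            mkldnn_format_tag_any);
    mkldnn_memory_desc_init_by_tag(&wei_md, 4, wei_dims, mkldnn_f32,
            mkldnn_format_tag_any);
    mkldnn_memory_desc_init_by_tag(&dst_md, 4, dst_dims, mkldnn_f32,
            mkldnn_format_tag_any);

    mkldnn_convolution_desc_t conv_desc;
    mkldnn_status_t s = mkldnn_convolution_forward_desc_init(&conv_desc,
            mkldnn_forward_inference, mkldnn_convolution_direct,
            &src_md, &wei_md, NULL /* no bias */, &dst_md,
            strides, padding, padding, mkldnn_padding_zero);
    if (s != mkldnn_success)
        return s;

    /* NULL attr -> default attributes; NULL hint -> not a backward primitive. */
    return mkldnn_primitive_desc_create(out_pd, &conv_desc, NULL, eng, NULL);
}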
-
-/** Initializes a dilated convolution descriptor @p conv_desc for forward
- * propagation using @p prop_kind (possible values are #mkldnn_forward_training
- * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
- * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
- * In order to create a dilated convolution without bias, @p bias_desc
- * should either be @c NULL or point to a descriptor with memory format kind
- * equal to #mkldnn_format_kind_undef.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- * - bias (#mkldnn_query_weights_md, 1), if created with bias
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a convolution descriptor @p conv_desc for backward propagation
- * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
- * padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a dilated convolution descriptor @p conv_desc for backward
- * propagation with respect to data using @p alg_kind, memory descriptors, @p
- * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a convolution descriptor @p conv_desc for backward propagation
- * with respect to weights using @p alg_kind, memory descriptors, @p strides,
- * @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_weights (#mkldnn_query_diff_weights_md, 0)
- * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
- */
-mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *diff_weights_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a convolution descriptor @p conv_desc for backward propagation
- * with respect to weights using @p alg_kind, memory descriptors, @p strides,
- * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_weights (#mkldnn_query_diff_weights_md, 0)
- * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
- */
-mkldnn_status_t MKLDNN_API
-mkldnn_dilated_convolution_backward_weights_desc_init(
- mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *diff_weights_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** @} */
-
-/** @addtogroup c_api_deconvolution Deconvolution
- * A primitive to compute deconvolution using different algorithms.
- *
- * @{ */
-
-
-/** Initializes a deconvolution descriptor @p deconv_desc for forward
- * propagation using @p prop_kind (possible values are #mkldnn_forward_training
- * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
- * @p padding_l, @p padding_r, and @p padding_kind. In order to create a
- * deconvolution without bias, @p bias_desc should either be @c NULL or point to
- * a descriptor with memory format kind equal to #mkldnn_format_kind_undef.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- * - bias (#mkldnn_query_weights_md, 1), if created with bias
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward
- * propagation using @p prop_kind (possible values are #mkldnn_forward_training
- * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
- * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to
- * create a dilated deconvolution without bias, @p bias_desc should either be
- * @c NULL or point to a descriptor with memory format kind equal to
- * #mkldnn_format_kind_undef.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- * - bias (#mkldnn_query_weights_md, 1), if created with bias
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a deconvolution descriptor @p conv_desc for backward propagation
- * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
- * padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a dilated deconvolution descriptor @p conv_desc for backward
- * propagation with respect to data using @p alg_kind, memory descriptors, @p
- * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a deconvolution descriptor @p conv_desc for backward propagation
- * with respect to weights using @p alg_kind, memory descriptors, @p strides,
- * @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_weights (#mkldnn_query_diff_weights_md, 0)
- * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
- */
-mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *diff_weights_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
- mkldnn_padding_kind_t padding_kind);
-
-/** Initializes a dilated deconvolution descriptor @p conv_desc for backward
- * propagation with respect to weights using @p alg_kind, memory descriptors,
- * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_weights (#mkldnn_query_diff_weights_md, 0)
- * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
- */
-mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(
- mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *diff_weights_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** @} */
-
-/** @addtogroup c_api_shuffle Shuffle
- * A primitive to shuffle data along the axis.
- * @{ */
-
-/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind,
- * memory descriptor @p data_desc, @p axis, and @p group_size.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- *
- */
-mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(
- mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind,
- const mkldnn_memory_desc_t *data_desc, int axis,
- mkldnn_dim_t group_size);
-
-/** Initializes a @p shuffle_desc for backward propagation using memory
- * descriptor @p diff_data_desc, @p axis, and @p group_size.
- *
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- *
- */
-mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(
- mkldnn_shuffle_desc_t *shuffle_desc,
- const mkldnn_memory_desc_t *diff_data_desc, int axis,
- mkldnn_dim_t group_size);
-
-/** @} */
-
-/** @addtogroup c_api_eltwise Eltwise
- * A primitive to compute element-wise operations like parametric rectifier
- * linear unit (ReLU).
- *
- * Both forward and backward passes support in-place operation; that is, src
- * and dst point to the same memory for forward pass, and diff_dst and diff_src
- * point to the same memory for backward pass.
- *
- * @warning Because the original src is required for backward pass, in-place
- * forward pass in general cannot be applied during training. However, for some
- * kinds of element-wise operations (namely ReLU with the alpha parameter equal to 0),
- * dst and src can be interchangeable for the backward pass, which enables
- * performing in-place forward even for training.
- *
- * @{ */
-
-/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind
- * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
- * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and
- * @p beta parameters.
- *
- * @sa mkldnn_eltwise_desc_t for details.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(
- mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
- float alpha, float beta);
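A short sketch for the common ReLU case (negative slope alpha = 0, beta unused), assuming data_md is an existing memory descriptor, eng an existing engine, and mkldnn_eltwise_relu / mkldnn_success the usual enumerators from mkldnn_types.h:

#include "mkldnn.h"

static mkldnn_status_t make_relu_pd(mkldnn_engine_t eng,
        const mkldnn_memory_desc_t *data_md,
        mkldnn_primitive_desc_t *out_pd) {
    mkldnn_eltwise_desc_t relu_desc;
    mkldnn_status_t s = mkldnn_eltwise_forward_desc_init(&relu_desc,
            mkldnn_forward_inference, mkldnn_eltwise_relu, data_md,
            0.f /* alpha */, 0.f /* beta */);
    if (s != mkldnn_success)
        return s;
    return mkldnn_primitive_desc_create(out_pd, &relu_desc, NULL, eng, NULL);
}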
-
-/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind
- * algorithm, memory descriptors @p diff_data_desc and @p data_desc, and the
- * @p alpha and @p beta parameters.
- *
- * @sa mkldnn_eltwise_desc_t for details.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(
- mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_data_desc,
- const mkldnn_memory_desc_t *data_desc, float alpha, float beta);
-
-/** @} */
-
-/** @addtogroup c_api_softmax Softmax
- * A primitive to perform softmax.
- *
- * \f[dst[ou][c][in] =
- *    \frac{\exp(src[ou][c][in]) - \max\limits_{c}(src[ou][c][in])}
- *    {\sum\limits_{c}\{\exp(src[ou][c][in])
- *    - \max\limits_{c}(src[ou][c][in])\}},\f]
- *
- * where \f$ou, in\f$ are the outer and inner indices respectively, defined
- * by @p data_desc.dims and @p softmax_axis.
- * @{ */
-
-/** Initializes a @p softmax_desc for forward propagation using @p prop_kind
- * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference)
- * and memory descriptor @p data_desc.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(
- mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind,
- const mkldnn_memory_desc_t *data_desc, int softmax_axis);
-
-/** Initializes a @p softmax_desc for backward propagation using memory
- * descriptors @p diff_desc and @p data_desc.
- *
- * Inputs:
- * - dst (#mkldnn_query_dst_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(
- mkldnn_softmax_desc_t *softmax_desc,
- const mkldnn_memory_desc_t *diff_desc,
- const mkldnn_memory_desc_t *data_desc, int softmax_axis);
-
-/** @} */
-
-/** @addtogroup c_api_pooling Pooling
- * A primitive to perform max or average pooling.
- *
- * Max pooling:
- * \f[dst[n][oc][oh][ow] =
- * \max\limits_{kw,kh}
- * (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw]),\f]
- *
- * Average pooling:
- * \f[dst[n][oc][oh][ow] =
- * \frac{1}{KW \cdot KH}\sum\limits_{kw,kh}
- * src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_r[1] + kw],\f]
- *
- * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and
- * output spatial dimensions are calculated similarly to how they are done in
- * convolution.
- *
- * During training, max pooling requires a workspace on forward
- * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to
- * save indices where maximum was found. The workspace layout is opaque, and
- * the indices cannot be restored from it. However, one can use backward
- * pooling to perform up-sampling (used in some detection topologies).
- *
- * @{ */
-
-/** Initializes a pooling descriptor @p pool_desc for forward propagation using
- * @p prop_kind (possible values are #mkldnn_forward_training and
- * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling
- * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l,
- * @p padding_r, and @p padding_kind.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if @p alg_kind = #mkldnn_pooling_max and
- * @p prop_kind = #mkldnn_forward_training
- */
-mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(
- mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
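A sketch of 3x3 max pooling with stride 2 for training; because alg_kind is #mkldnn_pooling_max and prop_kind is #mkldnn_forward_training, the created primitive descriptor also exposes a workspace. src_md and dst_md are assumed to be existing, mutually consistent memory descriptors; mkldnn_padding_zero is assumed to be the usual padding-kind enumerator.

#include "mkldnn.h"

static mkldnn_status_t make_max_pool_pd(mkldnn_engine_t eng,
        const mkldnn_memory_desc_t *src_md,
        const mkldnn_memory_desc_t *dst_md,
        mkldnn_primitive_desc_t *out_pd) {
    mkldnn_dims_t strides = {2, 2}, kernel = {3, 3}, padding = {0, 0};

    mkldnn_pooling_desc_t pool_desc;
    mkldnn_status_t s = mkldnn_pooling_forward_desc_init(&pool_desc,
            mkldnn_forward_training, mkldnn_pooling_max, src_md, dst_md,
            strides, kernel, padding, padding, mkldnn_padding_zero);
    if (s != mkldnn_success)
        return s;

    s = mkldnn_primitive_desc_create(out_pd, &pool_desc, NULL, eng, NULL);
    if (s != mkldnn_success)
        return s;

    /* The workspace layout is opaque; its descriptor is queried like any
     * other memory descriptor and used to create the workspace memory. */
    const mkldnn_memory_desc_t *ws_md = mkldnn_primitive_desc_query_md(
            *out_pd, mkldnn_query_workspace_md, 0);
    (void)ws_md;
    return s;
}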
-
-/** Initializes a pooling descriptor @p pool_desc for backward propagation
- * using @p alg_kind, memory descriptors, and pooling parameters in the spatial
- * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p
- * padding_kind.
- *
- * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if @p alg_kind = #mkldnn_pooling_max
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(
- mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
- const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
- const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
-
-/** @} */
-
-/** @addtogroup c_api_lrn LRN
- * A primitive to perform local response normalization (LRN) across or within
- * channels.
- *
- * LRN across channels:
- * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
- * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
- * (src[n][c+i][h][w])^2\right\}^{-\beta}
- * src[n][c][h][w],\f]
- *
- * LRN within channels:
- * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
- * \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
- * (src[n][c][h+i][w+i])^2\right\}^{-\beta}
- * src[n][c][h][w],\f]
- *
- * where \f$n_{l}\f$ is the @p local_size.
- *
- * During training, LRN might or might not require a workspace on forward
- * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The
- * behavior is implementation specific. Optimized implementations typically
- * require a workspace and use it to save some intermediate results from the
- * forward pass that accelerate computations on the backward pass.
- *
- * To check whether a workspace is required, query the LRN primitive descriptor
- * for the workspace (#mkldnn_query_workspace_md). Success indicates that the
- * workspace is required and its description will be returned.
- * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd
- *
- * @{ */
-
-/** Initializes an @p lrn_desc for forward propagation using @p prop_kind
- * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
- * @p alg_kind, memory descriptor @p data_desc, and regularization
- * parameters @p local_size, @p alpha, @p beta, and @p k.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if the underlying implementation requires
- */
-mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(
- mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind,
- mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
- mkldnn_dim_t local_size, float alpha, float beta, float k);
-
-/** Initializes an @p lrn_desc for backward propagation using @p alg_kind,
- * memory descriptors @p data_desc and @p diff_data_desc, and regularization
- * parameters @p local_size, @p alpha, @p beta, and @p k.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if the underlying implementation requires
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(
- mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind,
- const mkldnn_memory_desc_t *diff_data_desc,
- const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size,
- float alpha, float beta, float k);
-
-/** @} */
-
-/** @addtogroup c_api_batch_normalization Batch Normalization
- * A primitive to perform batch normalization.
- *
- * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]}
- * {\sqrt{\sigma[c] + eps}} + \beta[c],\f]
- *
- * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and,
- *
- * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$,
- * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn}
- * (src[n][c][h][w] - \mu[c])^2\f$,
- *
- * and @c eps is a constant to improve numerical stability.
- *
- * Both forward and backward passes support in-place operation; that is, src
- * and dst point to the same memory for forward pass, and diff_dst and diff_src
- * point to the same memory for backward pass.
- *
- * Batch normalization supports different flavors controlled by
- * mkldnn_batch_normalization_desc_t. For example, batch normalization can
- * compute the mean and variance on its own or take them as inputs. It can
- * either perform scaling and shifting using gamma and beta parameters or not.
- * Optionally it can also perform a fused ReLU, which in case of training would
- * also require a workspace.
- *
- * @sa mkldnn_batch_normalization_desc_t
- * @{ */
-
-/** Initializes a batch normalization descriptor @p bnrm_desc for forward
- * propagation using @p prop_kind (possible values are
- * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor
- * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit
- * flags of type mkldnn_batch_normalization_desc_t.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - mean (#mkldnn_query_src_md, 1),
- * if #mkldnn_use_global_stats bit-flags is set in @p flags
- * - variance (#mkldnn_query_src_md, 2),
- * if #mkldnn_use_global_stats bit-flags is set in @p flags
- * - scale_and_shift (#mkldnn_query_weights_md, 0),
- * if #mkldnn_use_scaleshift bit-flags is set in @p flags
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- * - mean (#mkldnn_query_dst_md, 1),
- * if #mkldnn_use_global_stats bit-flags is not set in @p flags
- *     and @p prop_kind = #mkldnn_forward_training
- * - variance (#mkldnn_query_dst_md, 2),
- * if #mkldnn_use_global_stats bit-flags is not set in @p flags
- * and @p prop_kind = #mkldnn_forward_training
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags
- * and @p prop_kind = #mkldnn_forward_training
- *
- * @note In-place operation is supported; that is, dst points to the same memory
- * as src.
- *
- * @sa mkldnn_batch_normalization_desc_t
- */
-mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(
- mkldnn_batch_normalization_desc_t *bnrm_desc,
- mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc,
- float epsilon, unsigned flags);
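A sketch of the inference-time flavor that consumes externally supplied mean and variance and applies scale/shift, i.e. the #mkldnn_use_global_stats and #mkldnn_use_scaleshift flags described above; data_md is an existing memory descriptor and eng an existing engine.

#include "mkldnn.h"

static mkldnn_status_t make_bnorm_pd(mkldnn_engine_t eng,
        const mkldnn_memory_desc_t *data_md,
        mkldnn_primitive_desc_t *out_pd) {
    const float epsilon = 1e-5f;
    unsigned flags = mkldnn_use_global_stats | mkldnn_use_scaleshift;

    mkldnn_batch_normalization_desc_t bn_desc;
    mkldnn_status_t s = mkldnn_batch_normalization_forward_desc_init(
            &bn_desc, mkldnn_forward_inference, data_md, epsilon, flags);
    if (s != mkldnn_success)
        return s;
    return mkldnn_primitive_desc_create(out_pd, &bn_desc, NULL, eng, NULL);
}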
-
-/** Initializes a batch normalization descriptor @p bnrm_desc for backward
- * propagation with respect to data and scale-shift parameters using memory
- * descriptors @p data_desc and @p diff_data_desc, normalization parameter
- * @p epsilon, and @p flags set using bit flags of type
- * mkldnn_batch_normalization_desc_t.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - mean (#mkldnn_query_src_md, 1)
- * - variance (#mkldnn_query_src_md, 2)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - scale_and_shift (#mkldnn_query_weights_md, 0),
- * if #mkldnn_use_scaleshift bit-flags is set in @p flags
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if #mkldnn_fuse_bn_relu bit-flags is set in @p flags
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- * - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0),
- * if #mkldnn_use_scaleshift bit-flags is set in @p flags
- * and @p prop_kind = #mkldnn_backward
- *
- * @note In-place operation is supported; that is, diff_src points to the same
- * memory as diff_dst.
- *
- * @sa mkldnn_batch_normalization_desc_t
- */
-mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(
- mkldnn_batch_normalization_desc_t *bnrm_desc,
- mkldnn_prop_kind_t prop_kind,
- const mkldnn_memory_desc_t *diff_data_desc,
- const mkldnn_memory_desc_t *data_desc,
- float epsilon, unsigned flags);
-
-/** @} */
-
-/** @addtogroup c_api_inner_product Inner product
- * A primitive to compute an inner product.
- *
- * Inner product layer is also known as fully connected layer.
- * With spatial dimension:
- *
- * \f[dst[n][oc] = \sum\limits_{ic, kh, kw}
- * src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw]
- * + bias[oc]\f]
- * @{ */
-
-/** Initializes an inner product descriptor @p ip_desc for forward propagation
- * using @p prop_kind (possible values are #mkldnn_forward_training and
- * #mkldnn_forward_inference) and memory descriptors. In order to create an
- * inner product without bias, @p bias_desc should be either @c NULL or a
- * pointer to a descriptor with memory format kind equal to
- * #mkldnn_format_kind_undef.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- * - bias (#mkldnn_query_weights_md, 1), if created with bias
- *
- * Outputs:
- * - dst (#mkldnn_query_dst_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(
- mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_desc);
-
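A hedged sketch of creating a bias-free forward inner product descriptor with the function above; `src_md`, `weights_md`, and `dst_md` are assumed to be initialized elsewhere:

    mkldnn_inner_product_desc_t ip_d;
    mkldnn_inner_product_forward_desc_init(&ip_d, mkldnn_forward_inference,
            &src_md, &weights_md, NULL /* no bias */, &dst_md);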
-/** Initializes an inner product descriptor @p ip_desc for backward propagation
- * with respect to data using memory descriptors.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- * - weights (#mkldnn_query_weights_md, 0)
- *
- * Outputs:
- * - diff_src (#mkldnn_query_diff_src_md, 0)
- */
-mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(
- mkldnn_inner_product_desc_t *ip_desc,
- const mkldnn_memory_desc_t *diff_src_desc,
- const mkldnn_memory_desc_t *weights_desc,
- const mkldnn_memory_desc_t *diff_dst_desc);
-
-/** Initializes an inner product descriptor @p ip_desc for backward propagation
- * with respect to weights using memory descriptors.
- *
- * @note Memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src (#mkldnn_query_src_md, 0)
- * - diff_dst (#mkldnn_query_diff_dst_md, 0)
- *
- * Outputs:
- * - diff_weights (#mkldnn_query_diff_weights_md, 0)
- * - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
- */
-mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(
- mkldnn_inner_product_desc_t *ip_desc,
- const mkldnn_memory_desc_t *src_desc,
- const mkldnn_memory_desc_t *diff_weights_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_desc);
-
-/** @} */
-
-/** @addtogroup c_api_rnn RNN
- * A primitive to compute the common recurrent layer.
- * @todo add additional description for the group
- * @{ */
-
-/**
- * Initializes a recurrent cell descriptor @p rnn_cell_desc
- * using @p kind (possible values are
- * #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and
- * #mkldnn_gru_linear_before_reset),
- * @p f (possible values are #mkldnn_eltwise_relu and
- * #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping.
- */
-mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(
- mkldnn_rnn_cell_desc_t *rnn_cell_desc,
- mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f,
- unsigned int flags, float alpha, float clipping);
-
-/** Returns the number of gates of a particular @p rnn_cell_desc. */
-int MKLDNN_API mkldnn_rnn_cell_get_gates_count(
- const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
-
-/** Returns the number of states of a particular @p rnn_cell_desc. */
-int MKLDNN_API mkldnn_rnn_cell_get_states_count(
- const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
-
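A short sketch (not from the original header) of initializing an LSTM cell descriptor and querying its gate/state counts with the functions above:

    mkldnn_rnn_cell_desc_t cell_d;
    mkldnn_rnn_cell_desc_init(&cell_d, mkldnn_vanilla_lstm, mkldnn_eltwise_tanh,
            0u /* flags */, 0.f /* alpha */, 0.f /* clipping */);
    int n_gates  = mkldnn_rnn_cell_get_gates_count(&cell_d);  /* 4 for LSTM */
    int n_states = mkldnn_rnn_cell_get_states_count(&cell_d); /* 2 for LSTM */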
-/** Sets quantization @p scale and @p shift for RNN data tensors.
- * For performance reasons, the low-precision configuration of the RNN
- * primitive expects input activations to have the unsigned int8 data type.
- * The scale and shift used to quantize floating-point data to unsigned
- * integers must be passed to the RNN primitive using attributes.
- * Example usage:
- * @code
- * // rnn parameters
- * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
- * // activations quantization parameters
- * float scale = ..., shift = ..;
- *
- * mkldnn_primitive_attr_t rnn_attr;
- * // create default attributes
- * mkldnn_primitive_attr_create(&rnn_attr);
- *
- * // set scale and shift for int8 quantization of activation
- * mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
- *
- * // create & configure rnn op_desc
- * mkldnn_rnn_desc_t rnn_d;
- * mkldnn_primitive_desc_t rnn_pd;
- *      mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, rnn_attr, engine, NULL);
- * @endcode
- * @note
- * Quantization scale and shift are common for src_layer, src_iter,
- * dst_iter and dst_layer.
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(
- mkldnn_primitive_attr_t attr, const float scale, const float shift);
-
-/** Sets quantization scales @p weights_scales for RNN weights tensors.
- * The low-precision configuration of the RNN primitive expects input weights
- * to have the signed int8 data type. The scales used to quantize
- * floating-point data to signed integers must be passed to the RNN primitive
- * using attributes.
- * The @p mask argument defines the correspondence between output tensor
- * dimensions and the @p weights_scales array. Set the i-th bit of @p mask to 1
- * to use a dedicated scaling factor for each slice of the output tensor over
- * the i-th dimension. Set @p mask to 0 to use a common scaling factor for the
- * whole output tensor. Example usage:
- * @code
- * // rnn parameters
- * int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
- * // unique output scales per output channel
- * float weights_scales[dic * n_gates] = { ... };
- * // mask that specifies last two dimensions of ldigo format
- * int mask = 0x3;
- *
- * mkldnn_primitive_attr_t attr;
- * // create default attributes
- * mkldnn_primitive_attr_create(&attr);
- *
- * // set output channel-wise weights scales
- * mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask,
- * weights_scales);
- *
- * // create & configure rnn op_desc
- * mkldnn_rnn_desc_t rnn_d;
- * mkldnn_primitive_desc_t rnn_pd;
- * mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL);
- * @endcode
- * @note
- * The dimension order is always native and does not depend on the actual
- * layout used. For example, 5-dimensional weights always have the
- * (l, d, i, g, o) logical dimension ordering.
- * @note
- * Quantization scales are common for weights_layer and weights_iteration.
- * @note
- * There is no way to check that @p count corresponds to @p mask until an
- * actual primitive descriptor is created, so it is the user's responsibility
- * to set proper values. The following formula must hold:
- *
- * \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
- */
-mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(
- mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask,
- const float *weights_scales);
-
-/** Initializes an RNN descriptor @p rnn_desc for forward propagation
- * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
- * @note If @p prop_kind equals #mkldnn_forward_training, you must query a
- * workspace memory descriptor before creating the primitive.
- *
- * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be
- * @c NULL or point to a zero memory descriptor, which would indicate that the
- * RNN primitive should not use them.
- *
- * @note All memory descriptors except @p src_iter_desc are allowed to be
- * initialized with #mkldnn_format_kind_any value of @p format_kind.
- *
- * Inputs:
- * - src_layer (#mkldnn_query_src_md, 0)
- * - src_iter (#mkldnn_query_src_md, 1), if used
- * - weights_layer (#mkldnn_query_weights_md, 0)
- * - weights_iter (#mkldnn_query_weights_md, 1)
- * - bias (#mkldnn_query_weights_md, 2), if used
- *
- * Outputs:
- * - dst_layer (#mkldnn_query_dst_md, 0)
- * - dst_iter (#mkldnn_query_dst_md, 1), if used
- * - workspace (#mkldnn_query_workspace_md, 0),
- * if @p prop_kind equals #mkldnn_forward_training
- */
-mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
- mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
- const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
- const mkldnn_rnn_direction_t direction,
- const mkldnn_memory_desc_t *src_layer_desc,
- const mkldnn_memory_desc_t *src_iter_desc,
- const mkldnn_memory_desc_t *weights_layer_desc,
- const mkldnn_memory_desc_t *weights_iter_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_layer_desc,
- const mkldnn_memory_desc_t *dst_iter_desc);
-
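A sketch of the forward initializer above for an inference-only RNN that skips the optional iteration and bias tensors by passing NULL; `cell_d` and the remaining memory descriptors are assumed to be initialized elsewhere:

    mkldnn_rnn_desc_t rnn_d;
    mkldnn_rnn_forward_desc_init(&rnn_d, mkldnn_forward_inference, &cell_d,
            mkldnn_unidirectional_left2right,
            &src_layer_md, NULL /* src_iter */,
            &weights_layer_md, &weights_iter_md, NULL /* bias */,
            &dst_layer_md, NULL /* dst_iter */);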
-/** Initializes an RNN descriptor @p rnn_desc for backward propagation
- * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
- *
- * @note All memory descriptors are allowed to be initialized with
- * #mkldnn_format_kind_any value of @p format_kind.
- *
- * @p src_iter_desc (simultaneously with @p diff_src_iter_desc),
- * @p bias_desc (simultaneously with @p diff_bias_desc), and
- * @p dst_iter_desc (simultaneously with @p diff_dst_iter_desc) are allowed to
- * either be @c NULL or point to a zero memory descriptor, which would indicate
- * that the RNN primitive should not use them.
- *
- * Inputs:
- * - src_layer (#mkldnn_query_src_md, 0)
- * - src_iter (#mkldnn_query_src_md, 1), if used
- * - weights_layer (#mkldnn_query_weights_md, 0)
- * - weights_iter (#mkldnn_query_weights_md, 1)
- * - bias (#mkldnn_query_weights_md, 2), if used
- * - dst_layer (#mkldnn_query_dst_md, 0)
- * - dst_iter (#mkldnn_query_dst_md, 1), if used
- * - diff_dst_layer (#mkldnn_query_diff_dst_md, 0)
- * - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used
- * - workspace (#mkldnn_query_workspace_md, 0)
- *
- * Outputs:
- * - diff_src_layer (#mkldnn_query_diff_src_md, 0)
- * - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used
- * - diff_weights_layer (#mkldnn_query_diff_weights_md, 0)
- * - diff_weights_iter (#mkldnn_query_diff_weights_md, 1)
- * - diff_bias (#mkldnn_query_diff_weights_md, 2), if used
- */
-mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
- mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
- const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
- const mkldnn_rnn_direction_t direction,
- const mkldnn_memory_desc_t *src_layer_desc,
- const mkldnn_memory_desc_t *src_iter_desc,
- const mkldnn_memory_desc_t *weights_layer_desc,
- const mkldnn_memory_desc_t *weights_iter_desc,
- const mkldnn_memory_desc_t *bias_desc,
- const mkldnn_memory_desc_t *dst_layer_desc,
- const mkldnn_memory_desc_t *dst_iter_desc,
- const mkldnn_memory_desc_t *diff_src_layer_desc,
- const mkldnn_memory_desc_t *diff_src_iter_desc,
- const mkldnn_memory_desc_t *diff_weights_layer_desc,
- const mkldnn_memory_desc_t *diff_weights_iter_desc,
- const mkldnn_memory_desc_t *diff_bias_desc,
- const mkldnn_memory_desc_t *diff_dst_layer,
- const mkldnn_memory_desc_t *diff_dst_iter_desc);
-
-/** @} */
-
-/** @} */
-
-/** @addtogroup c_api_engine Engine operations
- * @{ */
-
-/** Returns the number of engines of a particular @p kind. */
-size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind);
-
-/** Creates an @p engine of particular @p kind and @p index. */
-mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine,
- mkldnn_engine_kind_t kind, size_t index);
-
-/** Returns the kind of an @p engine. */
-mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine,
- mkldnn_engine_kind_t *kind);
-
-/** Destroys an @p engine. */
-mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine);
-
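A minimal sketch of the typical engine lifecycle built from the functions above:

    mkldnn_engine_t eng;
    if (mkldnn_engine_get_count(mkldnn_cpu) > 0) {
        mkldnn_engine_create(&eng, mkldnn_cpu, 0);
        /* ... create memory objects and primitives on `eng` ... */
        mkldnn_engine_destroy(eng);
    }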
-/** @} */
-
-/** @addtogroup c_api_stream Execution stream operations
- * @{ */
-
-/** Creates an execution @p stream for @p engine and with @p flags. */
-mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream,
- mkldnn_engine_t engine, unsigned flags);
-
-/** Destroys an execution @p stream. */
-mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream);
-
-/** @} */
-
-/** @addtogroup c_api_service Service functions
- * @{ */
-
-/** Sets verbosity level (print information to stdout).
- * Possible levels are:
- * - 0 -- no verbose output (default)
- * - 1 -- primitive information at execution
- * - 2 -- primitive information at creation and execution
- *
- * @note
- * Dumping information might affect performance.
- * This setting overrides the MKLDNN_VERBOSE environment variable. */
-mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level);
-
-/** Enables or disables dumping of JIT-generated code.
- * The enable parameter can be:
- * - 0 -- disable
- * - any other value -- enable
- *
- * @note
- * This setting overrides the MKLDNN_JIT_DUMP environment variable. */
-mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable);
-
-/** Gets library version information.
- * Version information includes:
- * - major -- major version number
- * - minor -- minor version number
- * - patch -- patch release number
- * - hash -- git commit hash */
-const mkldnn_version_t MKLDNN_API *mkldnn_version();
-
-/** @} */
-
-/** @addtogroup c_api_blas BLAS functions
- * A subset of Basic Linear Algebra (BLAS) functions to perform
- * matrix-matrix multiplication.
- * @{ */
-
-/** SGEMM performs a matrix-matrix multiplication operation defined as
- *
- * C := alpha*op( A )*op( B ) + beta*C
- *
- * where
- * - op( X ) is one of op( X ) = X or op( X ) = X**T,
- * - alpha and beta are scalars,
- * - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix
- * and C an m by n matrix.
- *
- * The matrices are assumed to be stored in column-major order (the elements
- * of a matrix column are contiguous in memory).
- *
- * @note
- * The API is different from the standard BLAS routine
- * because it returns mkldnn_status_t for error handling.
- * XERBLA is not supported: no error message will be printed
- * in case of incorrect parameters. */
-mkldnn_status_t MKLDNN_API mkldnn_sgemm(
- const char *transa, const char *transb,
- const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
- const float *alpha, const float *A, const mkldnn_dim_t *lda,
- const float *B, const mkldnn_dim_t *ldb,
- const float *beta, float *C, const mkldnn_dim_t *ldc);
-
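A small sketch of calling mkldnn_sgemm on two column-major 2x2 matrices (note that, unlike cblas_sgemm, all scalar parameters are passed by pointer):

    mkldnn_dim_t m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
    float alpha = 1.0f, beta = 0.0f;
    float A[4] = {1, 2, 3, 4}, B[4] = {5, 6, 7, 8}, C[4] = {0};
    mkldnn_sgemm("N", "N", &m, &n, &k, &alpha, A, &lda,
            B, &ldb, &beta, C, &ldc); /* C := alpha*A*B + beta*C */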
-/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication
- * operation and add the result to a scalar-matrix product. For the final
- * result, a vector is added to each row or column of the output matrix.
- * The operation is defined as:
- *
- * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset
- *
- * where
- * - op( X ) = X or op( X ) = X**T,
- * - A_offset is an m-by-k matrix with every element equal to the value oa,
- * - B_offset is a k-by-n matrix with every element equal to the value ob,
- * - C_offset is an m-by-n matrix defined by the oc array, size len:
- * - if offsetc = F: len must be at least 1
- * - if offsetc = C: len must be at least max(1, m)
- * - if offsetc = R: len must be at least max(1, n)
- * - alpha and beta are scalars, and A, B and C are matrices, with op( A )
- * an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix.
- *
- * The matrices are assumed to be stored in column-major order (the elements
- * of a matrix column are contiguous in memory).
- *
- * @note
- * The API is different from the standard BLAS routine
- * because it returns mkldnn_status_t for error handling.
- * XERBLA is not supported: no error message will be printed
- * in case of incorrect parameters. */
-mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32(
- const char *transa, const char *transb, const char *offsetc,
- const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
- const float *alpha,
- const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao,
- const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo,
- const float *beta,
- int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co);
-
-mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32(
- const char *transa, const char *transb, const char *offsetc,
- const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
- const float *alpha,
- const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao,
- const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo,
- const float *beta,
- int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co);
-/** @} */
-
-/** @} */
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp b/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
deleted file mode 100644
index 581400a013..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn.hpp
+++ /dev/null
@@ -1,2615 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_HPP
-#define MKLDNN_HPP
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-#include <stdlib.h>
-#include <memory>
-#include <vector>
-#include <unordered_map>
-#include <algorithm>
-#include <iterator>
-
-#include "mkldnn.h"
-#endif
-
-namespace mkldnn {
-
-/// @addtogroup cpp_api C++ API
-/// @{
-
-/// @addtogroup cpp_api_utils Utils
-/// @{
-
-/// A class that provides the destructor for an Intel(R) MKL-DNN C handle
-template <typename T> class handle_traits {};
-
-/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
-/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
-/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
-/// can be passed by value. This class enables wrapping:
-/// - Newly constructed handles.
-/// @n In this case, the constructed handle uses reference counting provided
-/// by @p std::shared_ptr with a proper deleter function specified through
-/// the @p handle_traits class.
-/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
-/// example, through mkldnn_primitive_get_primitive_desc()).
-/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
-/// deleter because it is assumed that the handle wrapper for the original
-/// object deletes the handle (this model is similar to @p std::weak_ptr).
-template <typename T, typename traits=handle_traits<T>> class handle {
-private:
- std::shared_ptr<typename std::remove_pointer<T>::type> _data;
- handle(const handle &&) = delete;
- handle &operator=(const handle &&other) = delete;
-protected:
- bool operator==(const T other) const { return other == _data.get(); }
- bool operator!=(const T other) const { return !(*this == other); }
-public:
- /// Constructs a C handle wrapper.
- /// @param t The C handle to wrap.
- /// @param weak A flag to specify whether to construct a weak wrapper.
- handle(T t = 0, bool weak = false): _data(0) {
- reset(t, weak);
- }
-
- handle(const handle &other): _data(other._data) {}
- handle &operator=(const handle &other) {
- _data = other._data;
- return *this;
- }
- /// Resets the value of a C handle.
- /// @param t The new value of the C handle.
- /// @param weak A flag to specify whether the wrapper should be weak.
- void reset(T t, bool weak = false) {
- auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
- _data.reset(t, weak ? dummy_destructor : traits::destructor);
- }
-
- /// Returns the value of the underlying C handle.
- T get() const { return _data.get(); }
-
- bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
- bool operator!=(const handle &other) const { return !(*this == other); }
-};
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-template <> struct handle_traits<mkldnn_memory_t> {
- static constexpr auto destructor = &mkldnn_memory_destroy;
-};
-
-template <> struct handle_traits<mkldnn_primitive_desc_t> {
- static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
-};
-
-template <> struct handle_traits<mkldnn_primitive_t> {
- static constexpr auto destructor = &mkldnn_primitive_destroy;
-};
-
-template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
- static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
-};
-#endif
-
-struct memory;
-struct primitive_desc;
-
-/// Base class for all computational primitives.
-class primitive: public handle<mkldnn_primitive_t> {
- friend struct error;
- friend struct stream;
- using handle::handle;
-public:
- /// A proxy to C primitive kind enum
- enum class kind {
- undefined_primitive = mkldnn_undefined_primitive,
- reorder = mkldnn_reorder,
- concat = mkldnn_concat,
- sum = mkldnn_sum,
- convolution = mkldnn_convolution,
- deconvolution = mkldnn_deconvolution,
- shuffle = mkldnn_shuffle,
- eltwise = mkldnn_eltwise,
- softmax = mkldnn_softmax,
- pooling = mkldnn_pooling,
- lrn = mkldnn_lrn,
- batch_normalization = mkldnn_batch_normalization,
- inner_product = mkldnn_inner_product,
- rnn = mkldnn_rnn,
- };
-
- primitive(const_mkldnn_primitive_desc_t c_pd);
- primitive(const primitive_desc &pd);
-
- /// Returns the descriptor of the underlying C API primitive.
- inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
- // TODO: use the C++ API wrapper structure.
-
- void execute(struct stream &astream,
- const std::unordered_map<int, memory> &args) const;
-};
-
-inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
- return static_cast<mkldnn_primitive_kind_t>(akind);
-}
-/// Intel(R) MKL-DNN exception class.
-///
-/// This class captures the status returned by the failed C API function, error
-/// message, and, optionally, handle of the primitive that caused the error.
-struct error: public std::exception {
- mkldnn_status_t status;
- const char *message;
-
- /// Constructs an error instance.
- ///
- /// @param astatus The error status returned by the C API.
- /// @param amessage The error message.
- error(mkldnn_status_t astatus, const char *amessage)
- : status(astatus), message(amessage) {}
-
- /// A convenience function for wrapping calls to the C API. Checks the
- /// return status and throws an #error in case of failure.
- ///
- /// @param status The error status returned by the C API.
- /// @param message The error message.
- static void wrap_c_api(mkldnn_status_t status, const char *message) {
- if (status != mkldnn_success)
- throw error(status, message);
- }
-};
-
-const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
- const_mkldnn_primitive_desc_t pd;
- error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
- "could not get primitive descriptor by primitive");
- return pd;
-}
-/// @}
-
-/// @addtogroup cpp_api_enums Common data types and enumerations
-/// A proxy to @ref c_api_types in @ref c_api.
-///
-/// @{
-
-enum scratchpad_mode {
- scratchpad_mode_library = mkldnn_scratchpad_mode_library,
- scratchpad_mode_user = mkldnn_scratchpad_mode_user,
-};
-
-inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
- return static_cast<mkldnn_scratchpad_mode_t>(mode);
-}
-
-enum padding_kind {
- zero = mkldnn_padding_zero
-};
-
-inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
- return static_cast<mkldnn_padding_kind_t>(kind);
-}
-
-enum prop_kind {
- forward_training = mkldnn_forward_training,
- forward_scoring = mkldnn_forward_scoring,
- forward_inference = mkldnn_forward_inference,
- forward = mkldnn_forward,
- backward = mkldnn_backward,
- backward_data = mkldnn_backward_data,
- backward_weights = mkldnn_backward_weights,
- backward_bias = mkldnn_backward_bias
-};
-
-inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
- return static_cast<mkldnn_prop_kind_t>(kind);
-}
-
-enum algorithm {
- algorithm_undef = mkldnn_alg_kind_undef,
- convolution_auto = mkldnn_convolution_auto,
- convolution_direct = mkldnn_convolution_direct,
- convolution_winograd = mkldnn_convolution_winograd,
- deconvolution_direct = mkldnn_deconvolution_direct,
- deconvolution_winograd = mkldnn_deconvolution_winograd,
- eltwise_relu = mkldnn_eltwise_relu,
- eltwise_tanh = mkldnn_eltwise_tanh,
- eltwise_elu = mkldnn_eltwise_elu,
- eltwise_square = mkldnn_eltwise_square,
- eltwise_abs = mkldnn_eltwise_abs,
- eltwise_sqrt = mkldnn_eltwise_sqrt,
- eltwise_linear = mkldnn_eltwise_linear,
- eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
- eltwise_soft_relu = mkldnn_eltwise_soft_relu,
- eltwise_logistic = mkldnn_eltwise_logistic,
- lrn_across_channels = mkldnn_lrn_across_channels,
- lrn_within_channel = mkldnn_lrn_within_channel,
- pooling_max = mkldnn_pooling_max,
- pooling_avg = mkldnn_pooling_avg,
- pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
- pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
- vanilla_rnn = mkldnn_vanilla_rnn,
- vanilla_lstm = mkldnn_vanilla_lstm,
- vanilla_gru = mkldnn_vanilla_gru,
- gru_linear_before_reset = mkldnn_gru_linear_before_reset
-};
-
-inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
- return static_cast<mkldnn_alg_kind_t>(aalgorithm);
-}
-
-enum batch_normalization_flag {
- use_global_stats = mkldnn_use_global_stats,
- use_scale_shift = mkldnn_use_scaleshift,
- fuse_bn_relu = mkldnn_fuse_bn_relu
-};
-
-inline mkldnn_batch_normalization_flag_t convert_to_c(
- batch_normalization_flag aflag) {
- return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
-}
-
-enum rnn_direction {
- unidirectional_left2right = mkldnn_unidirectional_left2right,
- unidirectional_right2left = mkldnn_unidirectional_right2left,
- unidirectional = mkldnn_unidirectional,
- bidirectional_concat = mkldnn_bidirectional_concat,
- bidirectional_sum = mkldnn_bidirectional_sum,
-};
-
-inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
- return static_cast<mkldnn_rnn_direction_t>(adir);
-}
-
-enum query {
- undef = mkldnn_query_undef,
-
- query_engine = mkldnn_query_engine,
- primitive_kind = mkldnn_query_primitive_kind,
-
- num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
- num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
-
- time_estimate_f64 = mkldnn_query_time_estimate_f64,
- memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
-
- query_scratchpad_engine = mkldnn_query_scratchpad_engine,
-
- impl_info_str = mkldnn_query_impl_info_str,
-
- op_d = mkldnn_query_op_d,
- convolution_d = mkldnn_query_convolution_d,
- deconvolution_d = mkldnn_query_deconvolution_d,
- shuffle_d = mkldnn_query_shuffle_d,
- eltwise_d = mkldnn_query_eltwise_d,
- softmax_d = mkldnn_query_softmax_d,
- pooling_d = mkldnn_query_pooling_d,
- lrn_d = mkldnn_query_lrn_d,
- batch_normalization_d = mkldnn_query_batch_normalization_d,
- inner_product_d = mkldnn_query_inner_product_d,
- rnn_d = mkldnn_query_rnn_d,
-
- src_md = mkldnn_query_src_md,
- diff_src_md = mkldnn_query_diff_src_md,
- weights_md = mkldnn_query_weights_md,
- diff_weights_md = mkldnn_query_diff_weights_md,
- dst_md = mkldnn_query_dst_md,
- diff_dst_md = mkldnn_query_diff_dst_md,
- workspace_md = mkldnn_query_workspace_md,
- scratchpad_md = mkldnn_query_scratchpad_md,
-};
-
-inline mkldnn_query_t convert_to_c(query aquery) {
- return static_cast<mkldnn_query_t>(aquery);
-}
-
-/// @}
-
-/// @addtogroup cpp_api_attr Attributes
-/// An extension for controlling primitive behavior.
-///
-/// @sa @ref c_api_attributes in @ref c_api
-/// @{
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-template <> struct handle_traits<mkldnn_post_ops_t> {
- static constexpr auto destructor = &mkldnn_post_ops_destroy;
-};
-#endif
-
-struct post_ops: public handle<mkldnn_post_ops_t> {
- post_ops() {
- mkldnn_post_ops_t result;
- error::wrap_c_api(mkldnn_post_ops_create(&result),
- "could not create post operation sequence");
- reset(result);
- }
-
- int len() const { return mkldnn_post_ops_len(get()); }
-
- primitive::kind kind(int index) const {
- error::wrap_c_api(
- index < len() ? mkldnn_success : mkldnn_invalid_arguments,
- "post_ops index is out of range");
- return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
- index));
- }
-
- void append_sum(float scale = 1.) {
- error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
- "could not append sum");
- }
-
- void get_params_sum(int index, float &scale) const {
- error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
- "could not get sum params");
- }
-
- void append_eltwise(float scale, algorithm alg, float alpha,
- float beta) {
- error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
- convert_to_c(alg), alpha, beta),
- "could not append eltwise");
- }
-
- void get_params_eltwise(int index, float &scale, algorithm &alg,
- float &alpha, float &beta) const {
- mkldnn_alg_kind_t c_alg;
- error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
- &scale, &c_alg, &alpha, &beta),
- "could not get eltwise params");
- alg = static_cast<algorithm>(c_alg);
- }
-};
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-template <> struct handle_traits<mkldnn_primitive_attr_t> {
- static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
-};
-#endif
-
-struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
- primitive_attr() {
- mkldnn_primitive_attr_t result;
- error::wrap_c_api(mkldnn_primitive_attr_create(&result),
- "could not create a primitive attr");
- reset(result);
- }
-
- scratchpad_mode get_scratchpad_mode() const {
- mkldnn_scratchpad_mode_t result;
- error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode(
- get(), &result), "could not get scratchpad mode");
- return scratchpad_mode(result);
- }
-
- void set_scratchpad_mode(scratchpad_mode mode) {
- error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode(
- get(), mkldnn::convert_to_c(mode)),
- "could not set scratchpad mode");
- }
-
- void get_output_scales(int &mask, std::vector<float> &scales) const
- {
- mkldnn_dim_t count;
- int c_mask;
- const float *c_scales;
- error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
- &count, &c_mask, &c_scales),
- "could not get int output scales");
- scales.resize(count);
-
- mask = c_mask;
- for (mkldnn_dim_t c = 0; c < count; ++c)
- scales[c] = c_scales[c];
- }
-
- void set_output_scales(int mask, const std::vector<float> &scales)
- {
- error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
- (mkldnn_dim_t)scales.size(), mask, &scales[0]),
- "could not set int output scales");
- }
-
- const post_ops get_post_ops() const {
- post_ops result;
- const_mkldnn_post_ops_t c_result;
- error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
- "could not get post operation sequence");
- result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
- return result;
- }
-
- void set_post_ops(post_ops ops) {
- error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
- "could not set post operation sequence");
- }
-
- void set_rnn_data_qparams(const float scale, const float shift)
- {
- error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
- scale, shift), "could not set rnn data int scale/shift");
- }
-
- void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
- {
- error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
- (int)scales.size(), mask, &scales[0]),
- "could not set rnn weights int scales");
- }
-};
-
-/// @}
-
-/// @addtogroup cpp_api_engine Engine
-/// Engine operations.
-///
-/// @sa @ref c_api_engine in @ref c_api
-/// @{
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-template <> struct handle_traits<mkldnn_engine_t> {
- static constexpr auto destructor = &mkldnn_engine_destroy;
-};
-#endif
-
-/// An execution engine.
-struct engine: public handle<mkldnn_engine_t> {
- friend class primitive;
- // gcc bug??? using handle::handle;
-
- /// Kinds of engines.
- enum kind {
- /// An unspecified engine
- any = mkldnn_any_engine,
- /// CPU engine
- cpu = mkldnn_cpu,
- };
-
- /// Returns the number of engines of a certain kind.
- ///
- /// @param akind The kind of engines to count.
-
- static size_t get_count(kind akind) {
- return mkldnn_engine_get_count(convert_to_c(akind));
- }
-
- /// Constructs an engine.
- ///
- /// @param akind The kind of engine to construct.
- /// @param index The index of the engine. Must be less than the value
- /// returned by #get_count() for this particular kind of engine.
-
- engine(kind akind, size_t index) {
- mkldnn_engine_t aengine;
- error::wrap_c_api(
- mkldnn_engine_create(&aengine,
- convert_to_c(akind), index),
- "could not create an engine");
- reset(aengine);
- }
-
- explicit engine(const mkldnn_engine_t& aengine)
- : handle(aengine, true) {}
-
- engine(const handle<mkldnn_primitive_desc_t> &pd) {
- mkldnn_engine_t engine_q;
- error::wrap_c_api(
- mkldnn_primitive_desc_query(pd.get(),
- mkldnn::convert_to_c(query_engine), 0, &engine_q),
- "could not get engine from primitive_desc");
- reset(engine_q, true);
- }
-
- template <class primitive_desc>
- static engine query(const primitive_desc &pd) {
- mkldnn_engine_t engine_q;
- error::wrap_c_api(
- mkldnn_primitive_desc_query(pd.get(),
- mkldnn::convert_to_c(query_engine), 0, &engine_q),
- "could not get engine from primitive_desc");
-
- return engine(engine_q);
- }
-
-private:
- static mkldnn_engine_kind_t convert_to_c(kind akind) {
- return static_cast<mkldnn_engine_kind_t>(akind);
- }
-};
-
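A one-line sketch of constructing the engine wrapper defined above, guarded by get_count():

    if (mkldnn::engine::get_count(mkldnn::engine::cpu) > 0) {
        mkldnn::engine eng(mkldnn::engine::cpu, 0);
        // ... use `eng` to create memory objects, primitives, and streams ...
    }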
-/// @}
-
-/// @addtogroup cpp_api_stream Stream
-/// Execution stream operations
-///
-/// @sa @ref c_api_stream in @ref c_api
-/// @{
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-template <> struct handle_traits<mkldnn_stream_t> {
- static constexpr auto destructor = &mkldnn_stream_destroy;
-};
-#endif
-
-struct stream: public handle<mkldnn_stream_t> {
- using handle::handle;
-
- enum: unsigned {
- default_flags = mkldnn_stream_default_flags,
- };
-
- /// Constructs a stream.
- stream(const engine &aengine,
- unsigned flags = static_cast<unsigned>(default_flags)) {
- mkldnn_stream_t astream;
- error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags),
- "could not create a stream");
- reset(astream);
- }
-};
-
-/// @}
-
-/// @addtogroup cpp_api_memory_related Memory and memory related operations
-/// @{
-
-/// @addtogroup cpp_api_memory Memory
-/// A primitive to describe and store data.
-///
-/// For more information, refer to @ref c_api_memory in @ref c_api.
-/// @{
-
-/// Memory that describes the data.
-struct memory: public handle<mkldnn_memory_t> {
- public:
- typedef mkldnn_dim_t dim;
- typedef std::vector<dim> dims;
-
- template <typename T> static void validate_dims(const std::vector<T> &v) {
- if (v.size() > MKLDNN_MAX_NDIMS)
- throw error(mkldnn_invalid_arguments, "invalid dimensions");
- }
-
- /// Data type specification. See #mkldnn_data_type_t for a detailed
- /// description.
- enum data_type {
- data_undef = mkldnn_data_type_undef,
- f32 = mkldnn_f32,
- s32 = mkldnn_s32,
- s8 = mkldnn_s8,
- u8 = mkldnn_u8,
- };
-
- /// Memory format tag specification. See #mkldnn_format_tag_t
- /// for a detailed description.
- enum format_tag {
- format_tag_undef = mkldnn_format_tag_undef,
- any = mkldnn_format_tag_any,
- a = mkldnn_a,
- ab = mkldnn_ab,
- abc = mkldnn_abc,
- abcd = mkldnn_abcd,
- abcde = mkldnn_abcde,
- abcdef = mkldnn_abcdef,
- abdec = mkldnn_abdec,
- acb = mkldnn_acb,
- acbde = mkldnn_acbde,
- acdb = mkldnn_acdb,
- acdeb = mkldnn_acdeb,
- ba = mkldnn_ba,
- bac = mkldnn_bac,
- bacd = mkldnn_bacd,
- bcda = mkldnn_bcda,
- cba = mkldnn_cba,
- cdba = mkldnn_cdba,
- cdeba = mkldnn_cdeba,
- decab = mkldnn_decab,
- Abc16a = mkldnn_Abc16a,
- ABc16a16b = mkldnn_ABc16a16b,
- aBc16b = mkldnn_aBc16b,
- ABc16b16a = mkldnn_ABc16b16a,
- Abc4a = mkldnn_Abc4a,
- aBc4b = mkldnn_aBc4b,
- ABc4b16a4b = mkldnn_ABc4b16a4b,
- ABc4b4a = mkldnn_ABc4b4a,
- ABc8a16b2a = mkldnn_ABc8a16b2a,
- ABc8a8b = mkldnn_ABc8a8b,
- aBc8b = mkldnn_aBc8b,
- ABc8b16a2b = mkldnn_ABc8b16a2b,
- ABc8b8a = mkldnn_ABc8b8a,
- Abcd16a = mkldnn_Abcd16a,
- ABcd16a16b = mkldnn_ABcd16a16b,
- aBcd16b = mkldnn_aBcd16b,
- ABcd16b16a = mkldnn_ABcd16b16a,
- aBCd16b16c = mkldnn_aBCd16b16c,
- aBCd16c16b = mkldnn_aBCd16c16b,
- Abcd4a = mkldnn_Abcd4a,
- aBcd4b = mkldnn_aBcd4b,
- ABcd4b16a4b = mkldnn_ABcd4b16a4b,
- ABcd4b4a = mkldnn_ABcd4b4a,
- aBCd4c16b4c = mkldnn_aBCd4c16b4c,
- aBCd4c4b = mkldnn_aBCd4c4b,
- ABcd8a16b2a = mkldnn_ABcd8a16b2a,
- ABcd8a8b = mkldnn_ABcd8a8b,
- aBcd8b = mkldnn_aBcd8b,
- ABcd8b16a2b = mkldnn_ABcd8b16a2b,
- aBCd8b16c2b = mkldnn_aBCd8b16c2b,
- ABcd8b8a = mkldnn_ABcd8b8a,
- aBCd8b8c = mkldnn_aBCd8b8c,
- aBCd8c16b2c = mkldnn_aBCd8c16b2c,
- aBCd8c8b = mkldnn_aBCd8c8b,
- Abcde16a = mkldnn_Abcde16a,
- ABcde16a16b = mkldnn_ABcde16a16b,
- aBcde16b = mkldnn_aBcde16b,
- ABcde16b16a = mkldnn_ABcde16b16a,
- aBCde16b16c = mkldnn_aBCde16b16c,
- aBCde16c16b = mkldnn_aBCde16c16b,
- aBCde2c8b4c = mkldnn_aBCde2c8b4c,
- Abcde4a = mkldnn_Abcde4a,
- aBcde4b = mkldnn_aBcde4b,
- ABcde4b4a = mkldnn_ABcde4b4a,
- aBCde4b4c = mkldnn_aBCde4b4c,
- aBCde4c16b4c = mkldnn_aBCde4c16b4c,
- aBCde4c4b = mkldnn_aBCde4c4b,
- Abcde8a = mkldnn_Abcde8a,
- ABcde8a8b = mkldnn_ABcde8a8b,
- aBcde8b = mkldnn_aBcde8b,
- ABcde8b16a2b = mkldnn_ABcde8b16a2b,
- aBCde8b16c2b = mkldnn_aBCde8b16c2b,
- ABcde8b8a = mkldnn_ABcde8b8a,
- aBCde8b8c = mkldnn_aBCde8b8c,
- aBCde8c16b2c = mkldnn_aBCde8c16b2c,
- aBCde8c8b = mkldnn_aBCde8c8b,
- aBcdef16b = mkldnn_aBcdef16b,
- aBCdef16b16c = mkldnn_aBCdef16b16c,
- aBCdef16c16b = mkldnn_aBCdef16c16b,
- aBcdef4b = mkldnn_aBcdef4b,
- aBCdef4c4b = mkldnn_aBCdef4c4b,
- aBCdef8b8c = mkldnn_aBCdef8b8c,
- aBCdef8c16b2c = mkldnn_aBCdef8c16b2c,
- aBCdef8c8b = mkldnn_aBCdef8c8b,
- aBdc16b = mkldnn_aBdc16b,
- aBdc4b = mkldnn_aBdc4b,
- aBdc8b = mkldnn_aBdc8b,
- aBdec16b = mkldnn_aBdec16b,
- aBdec4b = mkldnn_aBdec4b,
- aBdec8b = mkldnn_aBdec8b,
- aBdefc16b = mkldnn_aBdefc16b,
- aBdefc4b = mkldnn_aBdefc4b,
- aBdefc8b = mkldnn_aBdefc8b,
- Acb16a = mkldnn_Acb16a,
- Acb4a = mkldnn_Acb4a,
- Acb8a = mkldnn_Acb8a,
- aCBd16b16c = mkldnn_aCBd16b16c,
- aCBde16b16c = mkldnn_aCBde16b16c,
- Acdb16a = mkldnn_Acdb16a,
- Acdb4a = mkldnn_Acdb4a,
- Acdb8a = mkldnn_Acdb8a,
- Acdeb16a = mkldnn_Acdeb16a,
- Acdeb4a = mkldnn_Acdeb4a,
- Acdeb8a = mkldnn_Acdeb8a,
- BAc16a16b = mkldnn_BAc16a16b,
- BAcd16a16b = mkldnn_BAcd16a16b,
- format_tag_last = mkldnn_format_tag_last,
-
- x = mkldnn_x,
- nc = mkldnn_nc,
- cn = mkldnn_cn,
- ncw = mkldnn_ncw,
- nwc = mkldnn_nwc,
- nchw = mkldnn_nchw,
- nhwc = mkldnn_nhwc,
- chwn = mkldnn_chwn,
- ncdhw = mkldnn_ncdhw,
- ndhwc = mkldnn_ndhwc,
- oi = mkldnn_oi,
- io = mkldnn_io,
- oiw = mkldnn_oiw,
- wio = mkldnn_wio,
- oihw = mkldnn_oihw,
- hwio = mkldnn_hwio,
- ihwo = mkldnn_ihwo,
- iohw = mkldnn_iohw,
- oidhw = mkldnn_oidhw,
- dhwio = mkldnn_dhwio,
- goiw = mkldnn_goiw,
- goihw = mkldnn_goihw,
- hwigo = mkldnn_hwigo,
- giohw = mkldnn_giohw,
- goidhw = mkldnn_goidhw,
- tnc = mkldnn_tnc,
- ntc = mkldnn_ntc,
- ldsnc = mkldnn_ldsnc,
- ldigo = mkldnn_ldigo,
- ldgoi = mkldnn_ldgoi,
- ldgo = mkldnn_ldgo,
- nCdhw16c = mkldnn_nCdhw16c,
- nCdhw4c = mkldnn_nCdhw4c,
- nCdhw8c = mkldnn_nCdhw8c,
- nChw16c = mkldnn_nChw16c,
- nChw4c = mkldnn_nChw4c,
- nChw8c = mkldnn_nChw8c,
- nCw16c = mkldnn_nCw16c,
- nCw4c = mkldnn_nCw4c,
- nCw8c = mkldnn_nCw8c,
- IOw16o16i = mkldnn_IOw16o16i,
- OIw16i16o = mkldnn_OIw16i16o,
- OIw16o16i = mkldnn_OIw16o16i,
- Oiw16o = mkldnn_Oiw16o,
- OIw4i16o4i = mkldnn_OIw4i16o4i,
- OIw4i4o = mkldnn_OIw4i4o,
- Oiw4o = mkldnn_Oiw4o,
- OIw8i16o2i = mkldnn_OIw8i16o2i,
- OIw8i8o = mkldnn_OIw8i8o,
- OIw8o16i2o = mkldnn_OIw8o16i2o,
- OIw8o8i = mkldnn_OIw8o8i,
- Owi16o = mkldnn_Owi16o,
- Owi4o = mkldnn_Owi4o,
- Owi8o = mkldnn_Owi8o,
- IOhw16o16i = mkldnn_IOhw16o16i,
- Ohwi16o = mkldnn_Ohwi16o,
- Ohwi4o = mkldnn_Ohwi4o,
- Ohwi8o = mkldnn_Ohwi8o,
- OIhw16i16o = mkldnn_OIhw16i16o,
- OIhw16o16i = mkldnn_OIhw16o16i,
- Oihw16o = mkldnn_Oihw16o,
- OIhw4i16o4i = mkldnn_OIhw4i16o4i,
- OIhw4i4o = mkldnn_OIhw4i4o,
- Oihw4o = mkldnn_Oihw4o,
- OIhw8i16o2i = mkldnn_OIhw8i16o2i,
- OIhw8i8o = mkldnn_OIhw8i8o,
- OIhw8o16i2o = mkldnn_OIhw8o16i2o,
- OIhw8o8i = mkldnn_OIhw8o8i,
- Odhwi16o = mkldnn_Odhwi16o,
- Odhwi4o = mkldnn_Odhwi4o,
- Odhwi8o = mkldnn_Odhwi8o,
- OIdhw16i16o = mkldnn_OIdhw16i16o,
- OIdhw16o16i = mkldnn_OIdhw16o16i,
- Oidhw16o = mkldnn_Oidhw16o,
- OIdhw4i4o = mkldnn_OIdhw4i4o,
- Oidhw4o = mkldnn_Oidhw4o,
- OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
- OIdhw8i8o = mkldnn_OIdhw8i8o,
- OIdhw8o8i = mkldnn_OIdhw8o8i,
- gIOw16o16i = mkldnn_gIOw16o16i,
- gOIw16i16o = mkldnn_gOIw16i16o,
- gOIw16o16i = mkldnn_gOIw16o16i,
- gOiw16o = mkldnn_gOiw16o,
- gOIw4i16o4i = mkldnn_gOIw4i16o4i,
- gOIw4i4o = mkldnn_gOIw4i4o,
- gOiw4o = mkldnn_gOiw4o,
- gOIw8i16o2i = mkldnn_gOIw8i16o2i,
- gOIw8i8o = mkldnn_gOIw8i8o,
- gOIw8o16i2o = mkldnn_gOIw8o16i2o,
- gOIw8o8i = mkldnn_gOIw8o8i,
- gOwi16o = mkldnn_gOwi16o,
- gOwi4o = mkldnn_gOwi4o,
- gOwi8o = mkldnn_gOwi8o,
- gIOhw16o16i = mkldnn_gIOhw16o16i,
- gOhwi16o = mkldnn_gOhwi16o,
- gOhwi4o = mkldnn_gOhwi4o,
- gOhwi8o = mkldnn_gOhwi8o,
- Goihw16g = mkldnn_Goihw16g,
- gOIhw16i16o = mkldnn_gOIhw16i16o,
- gOIhw16o16i = mkldnn_gOIhw16o16i,
- gOihw16o = mkldnn_gOihw16o,
- gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
- gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
- gOIhw4i4o = mkldnn_gOIhw4i4o,
- gOIhw4o4i = mkldnn_gOIhw4o4i,
- gOihw4o = mkldnn_gOihw4o,
- Goihw8g = mkldnn_Goihw8g,
- gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
- gOIhw8i8o = mkldnn_gOIhw8i8o,
- gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
- gOIhw8o8i = mkldnn_gOIhw8o8i,
- gOdhwi16o = mkldnn_gOdhwi16o,
- gOdhwi4o = mkldnn_gOdhwi4o,
- gOdhwi8o = mkldnn_gOdhwi8o,
- gOIdhw16i16o = mkldnn_gOIdhw16i16o,
- gOIdhw16o16i = mkldnn_gOIdhw16o16i,
- gOidhw16o = mkldnn_gOidhw16o,
- gOIdhw4i4o = mkldnn_gOIdhw4i4o,
- gOidhw4o = mkldnn_gOidhw4o,
- gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
- gOIdhw8i8o = mkldnn_gOIdhw8i8o,
- gOIdhw8o8i = mkldnn_gOIdhw8o8i,
- };
-
- /// A memory descriptor.
- struct desc {
- friend struct memory;
- /// The underlying C API data structure.
- mkldnn_memory_desc_t data;
-
- /// Constructs a zero memory descriptor
- desc(): data() {}
-
- /// Constructs a memory descriptor.
- ///
- /// @param adims Data dimensions
- /// @param adata_type Data precision/type.
- /// @param aformat Data layout format tag.
- desc(const dims &adims, data_type adata_type,
- format_tag aformat) {
- validate_dims(adims);
- error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(),
- adims.size() == 0 ? nullptr : &adims[0],
- convert_to_c(adata_type), convert_to_c(aformat)),
- "could not initialize a memory descriptor");
- }
-
- /// Constructs a memory descriptor from a C API data structure.
- ///
- /// @param adata A C API #mkldnn_memory_desc_t structure.
- desc(const mkldnn_memory_desc_t &adata): data(adata) {}
-
-        /// Constructs a sub-memory descriptor.
-        ///
- /// @param adims Sizes of a sub-memory
- /// @param offsets Offsets of a sub-memory
- desc submemory_desc(const dims &adims, const dims &offsets) {
- mkldnn_memory_desc_t sub_md;
- error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md,
- &data, &adims[0], &offsets[0]),
- "could not initialize a sub-memory");
- return desc(sub_md);
- }
-
- /// Returns the number of bytes required to allocate the memory described
- /// including the padding area.
- size_t get_size() const { return mkldnn_memory_desc_get_size(&data); }
-
- bool operator==(const desc &other) const {
- return mkldnn_memory_desc_equal(&data, &other.data) != 0;
- }
-
- bool operator!=(const desc &other) const { return !operator==(other); }
- };
-
- /// Constructs a memory.
- ///
- /// @param md Memory descriptor.
- /// @param aengine Engine.
- /// @param ahandle Native handle.
- memory(const desc &md, const engine &aengine, void *ahandle) {
- mkldnn_memory_t result;
- error::wrap_c_api(mkldnn_memory_create(&result, &md.data,
- aengine.get(), ahandle), "could not create a memory");
- reset(result);
- }
-
- /// Constructs a memory.
- ///
- /// @param md Memory descriptor.
- /// @param aengine Engine.
- memory(const desc &md, const engine &aengine)
- : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {}
-
- /// Returns the descriptor of the memory.
- desc get_desc() const {
- const mkldnn_memory_desc_t *cdesc;
- error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc),
- "could not get memory descriptor from a memory");
- return desc(*cdesc);
- }
-
- /// Returns the engine of the memory.
- engine get_engine() const {
- mkldnn_engine_t engine_q;
- error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q),
- "could not get engine from a memory");
- return engine(engine_q);
- }
-
- /// Returns a handle of the data contained in the memory.
- ///
- /// On the CPU engine, this is a pointer to the allocated memory.
- void *get_data_handle() const {
- void *handle;
- error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
- "could not get native handle");
- return handle;
- }
-
- void set_data_handle(void *handle) const {
- error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
- "could not set native handle");
- }
-
- // Must go away or be private:
- static mkldnn_data_type_t convert_to_c(data_type adata_type) {
- return static_cast<mkldnn_data_type_t>(adata_type);
- }
- static mkldnn_format_tag_t convert_to_c(format_tag aformat) {
- return static_cast<mkldnn_format_tag_t>(aformat);
- }
-};
-
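A brief sketch of describing and allocating a 4D f32 tensor with the memory wrapper above (the library allocates the buffer because no native handle is supplied):

    using namespace mkldnn;
    engine eng(engine::cpu, 0);
    memory::desc md({2, 16, 7, 7}, memory::f32, memory::nchw);
    memory mem(md, eng);                 // buffer allocated by the library
    float *buf = static_cast<float *>(mem.get_data_handle());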
-inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
- return a == memory::convert_to_c(b);
-}
-inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
- return !(a == b);
-}
-inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
- return b == a;
-}
-inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
- return !(a == b);
-}
-
-inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) {
- return a == memory::convert_to_c(b);
-}
-inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) {
- return !(a == b);
-}
-inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) {
- return b == a;
-}
-inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) {
- return !(a == b);
-}
-
-/// @}
-
-/// @addtogroup cpp_api_reorder Reorder
-/// A primitive to copy data between memory formats.
-///
-/// @sa @ref c_api_reorder in @ref c_api
-/// @{
-
-struct reorder : public primitive {
- struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
- primitive_desc(const engine &src_engine, const memory::desc &src_md,
- const engine &dst_engine, const memory::desc &dst_md,
- const primitive_attr &aattr) {
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
- src_engine.get(), &src_md.data,
- dst_engine.get(), &dst_md.data, aattr.get()),
- "could not create a reorder primitive descriptor");
- reset(result);
- }
-
- primitive_desc(const engine &src_engine, const memory::desc &src_md,
- const engine &dst_engine, const memory::desc &dst_md) {
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
- src_engine.get(), &src_md.data,
- dst_engine.get(), &dst_md.data, nullptr),
- "could not create a reorder primitive descriptor");
- reset(result);
- }
-
- primitive_desc(const memory &src, const memory &dst,
- const primitive_attr &aattr) {
- mkldnn_primitive_desc_t result;
- auto src_md = src.get_desc();
- auto dst_md = dst.get_desc();
- error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
- src.get_engine().get(), &src_md.data,
- dst.get_engine().get(), &dst_md.data, aattr.get()),
- "could not create a reorder primitive descriptor");
- reset(result);
- }
-
- primitive_desc(const memory &src, const memory &dst) {
- mkldnn_primitive_desc_t result;
- auto src_md = src.get_desc();
- auto dst_md = dst.get_desc();
- error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
- src.get_engine().get(), &src_md.data,
- dst.get_engine().get(), &dst_md.data, nullptr),
- "could not create a reorder primitive descriptor");
- reset(result);
- }
-
- memory::desc scratchpad_desc() const {
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(scratchpad_md), 0);
- if (cdesc == nullptr)
- return memory::desc();
- return memory::desc(*cdesc);
- }
-
- engine scratchpad_engine() {
- mkldnn_engine_t engine_q;
- error::wrap_c_api(
- mkldnn_primitive_desc_query(get(),
- mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q),
- "could not get scratchpad engine from reorder primitive_desc");
-
- return engine(engine_q);
- }
-
- engine get_engine() { return engine::query(*this); }
- };
-
- reorder(const primitive_desc &pd): primitive(pd.get()) {}
-
- reorder(const memory &src, const memory &dst):
- primitive(primitive_desc(src, dst).get()) {}
-
- void execute(stream astream, memory &src, memory &dst) {
- primitive::execute(astream,
- {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
- }
-};
-
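A short sketch of the reorder wrapper above; `src_mem`, `dst_mem`, and the stream `s` are assumed to exist already (for example, an nchw to nChw16c conversion):

    mkldnn::reorder r(src_mem, dst_mem);
    r.execute(s, src_mem, dst_mem);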
-/// @}
-
-/// @addtogroup cpp_api_concat Concat
-/// A primitive to concatenate data by arbitrary dimension.
-///
-/// @sa @ref c_api_concat in @ref c_api
-/// @{
-
-struct concat : public primitive {
- struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
- std::vector<mkldnn_memory_desc_t> cpp_to_c(
- const std::vector<memory::desc> &srcs) {
- std::vector<mkldnn_memory_desc_t> c_api_srcs;
- c_api_srcs.reserve(srcs.size());
- for (const auto &s : srcs) c_api_srcs.push_back(s.data);
- return c_api_srcs;
- }
-
- primitive_desc(const memory::desc &dst, int concat_dimension,
- const std::vector<memory::desc> &srcs, const engine &aengine) {
- auto c_api_srcs = cpp_to_c(srcs);
-
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_concat_primitive_desc_create(
- &result, &dst.data, (int)c_api_srcs.size(),
- concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
- "could not create a concat primitive descriptor");
- reset(result);
- }
-
- primitive_desc(int concat_dimension,
- const std::vector<memory::desc> &srcs, const engine &aengine) {
- auto c_api_srcs = cpp_to_c(srcs);
-
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_concat_primitive_desc_create(
- &result, nullptr, (int)c_api_srcs.size(),
- concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
- "could not create a concat primitive descriptor");
- reset(result);
- }
-
- memory::desc dst_desc() const {
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(dst_md), 0);
- error::wrap_c_api(
- cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
- "could not get a dst memory descriptor");
- return memory::desc(*cdesc);
- }
-
- memory::desc scratchpad_desc() const {
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(scratchpad_md), 0);
- if (cdesc == nullptr)
- return memory::desc();
- return memory::desc(*cdesc);
- }
-
- engine get_engine() { return engine::query(*this); }
- };
-
- concat(const primitive_desc &pd): primitive(pd.get()) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_sum Sum
-/// A primitive to sum data.
-///
-/// @sa @ref c_api_sum in @ref c_api
-/// @{
-
-struct sum : public primitive {
- struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
- std::vector<mkldnn_memory_desc_t> cpp_to_c(
- const std::vector<memory::desc> &srcs) {
- std::vector<mkldnn_memory_desc_t> c_api_srcs;
- c_api_srcs.reserve(srcs.size());
- for (const auto &s : srcs) c_api_srcs.push_back(s.data);
- return c_api_srcs;
- }
-
- primitive_desc(const memory::desc &dst,
- const std::vector<float> &scales,
- const std::vector<memory::desc> &srcs, const engine &aengine) {
- error::wrap_c_api(scales.size() == srcs.size()
- ? mkldnn_success : mkldnn_invalid_arguments,
- "number of scales not equal to number of srcs");
-
- auto c_api_srcs = cpp_to_c(srcs);
-
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_sum_primitive_desc_create(
- &result, &dst.data, (int)c_api_srcs.size(),
- &scales[0], &c_api_srcs[0], nullptr, aengine.get()),
- "could not create a sum primitive descriptor");
- reset(result);
- }
-
- primitive_desc(const std::vector<float> &scales,
- const std::vector<memory::desc> &srcs, const engine &aengine) {
- error::wrap_c_api(scales.size() == srcs.size()
- ? mkldnn_success : mkldnn_invalid_arguments,
- "number of scales not equal to number of srcs");
-
- auto c_api_srcs = cpp_to_c(srcs);
- mkldnn_primitive_desc_t result;
- error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result,
- nullptr, (int)c_api_srcs.size(), &scales[0],
- &c_api_srcs[0], nullptr, aengine.get()),
- "could not create a sum primitive descriptor");
- reset(result);
- }
-
- memory::desc dst_desc() const {
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(dst_md), 0);
- error::wrap_c_api(
- cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
- "could not get a dst memory descriptor");
- return memory::desc(*cdesc);
- }
-
- memory::desc scratchpad_desc() const {
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(scratchpad_md), 0);
- if (cdesc == nullptr)
- return memory::desc();
- return memory::desc(*cdesc);
- }
-
- engine get_engine() { return engine::query(*this); }
- };
-
- sum(const primitive_desc &pd): primitive(pd.get()) {}
-};
-
-/// @}
-
-/// @}
-
-/// @addtogroup cpp_api_primitives Primitives
-/// @{
-
-/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
-/// @{
-
-/// A base class for all primitive descriptors.
-struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
- primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
- const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
- mkldnn_primitive_desc_iterator_t iterator = nullptr;
- mkldnn_status_t status = mkldnn_primitive_desc_iterator_create(
- &iterator, desc, attr ? attr->get() : nullptr, e.get(),
- hint_fwd_pd);
- error::wrap_c_api(status,
- "could not create a primitive descriptor iterator");
- pd_iterator.reset(iterator);
- fetch_impl();
- }
-
- engine get_engine() { return engine::query(*this); }
-
- primitive_attr get_primitive_attr() const {
- const_mkldnn_primitive_attr_t const_cattr;
- error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
- "could not get attributes");
- mkldnn_primitive_attr_t cattr;
- error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
- "could not clone attributes");
-
- primitive_attr attr;
- attr.reset(cattr);
- return attr;
- }
-
- /// Returns implementation name
- const char *impl_info_str() const {
- const char *res;
- error::wrap_c_api(mkldnn_primitive_desc_query(get(),
- mkldnn_query_impl_info_str, 0, &res),
- "could not query implementation info string");
- return res;
- }
-
- /// Queries the memory::dim value (same as int64_t)
- memory::dim query_s64(query q) const {
- memory::dim res;
- mkldnn_status_t status = mkldnn_primitive_desc_query(get(),
- mkldnn::convert_to_c(q), 0, &res);
- return status == mkldnn_success ? res : 0;
- }
-
-    /// Advances to the next implementation for the given op descriptor.
-    ///
-    /// Returns:
-    /// - @c true on success
-    /// - @c false if the last implementation has already been reached, in
-    ///   which case the primitive descriptor itself is kept unchanged
- bool next_impl() {
- mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
- pd_iterator.get());
- if (status == mkldnn_iterator_ends) return false;
- error::wrap_c_api(status, "primitive descriptor iterator next failed");
-
- fetch_impl();
- return true;
- }
-
-    /// Queries and returns the requested memory descriptor.
- memory::desc query_md(query what, int idx = 0) const {
- std::vector<query> valid_q{src_md, diff_src_md, weights_md,
- diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md};
- if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
- [=](query q) { return what == q; }))
- throw error(mkldnn_invalid_arguments, "invalid memory query");
-
- const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
- get(), mkldnn::convert_to_c(what), idx);
- if (cdesc == nullptr) return memory::desc();
-
- return memory::desc(*cdesc);
- }
-
- // register specialized queries, e.g. src_desc()
-# define REG_QUERY_MD(name, what, idx) \
- memory::desc name ## _desc() const { return query_md(what ## _md, idx); }
-
- private:
- handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
- void fetch_impl() {
- mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
- pd_iterator.get());
- error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
- "could not fetch a primitive descriptor from the iterator");
- reset(pd);
- }
-};
-
-/// @}
-
-/// @addtogroup cpp_api_convolution Convolution
-/// A primitive to compute convolution using different algorithms.
-///
-/// @sa @ref c_api_convolution in @ref c_api
-/// @{
-
-struct convolution_forward: public primitive {
- struct desc {
- mkldnn_convolution_desc_t data;
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, &bias_desc.data,
- &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, nullptr,
- &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(
- mkldnn_dilated_convolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, &bias_desc.data,
- &dst_desc.data, &strides[0], &dilates[0],
- &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated convolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(
- mkldnn_dilated_convolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, nullptr,
- &dst_desc.data, &strides[0], &dilates[0],
- &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated convolution forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(bias, weights, 1);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- convolution_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct convolution_backward_data : public primitive {
- struct desc {
- mkldnn_convolution_desc_t data;
- desc(algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
- &data, convert_to_c(aalgorithm), &diff_src_desc.data,
- &weights_desc.data, &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward data descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(
- mkldnn_dilated_convolution_backward_data_desc_init(
- &data, convert_to_c(aalgorithm), &diff_src_desc.data,
- &weights_desc.data, &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward data descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const convolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const convolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- convolution_backward_data(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct convolution_backward_weights : public primitive {
- struct desc {
- mkldnn_convolution_desc_t data;
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, &diff_bias_desc.data,
- &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, &diff_bias_desc.data,
- &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a convolution backward weights descriptor");
- }
-
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const convolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const convolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(diff_weights, diff_weights, 0);
- REG_QUERY_MD(diff_bias, diff_weights, 1);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- convolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-//
-/// @addtogroup cpp_api_deconvolution Deconvolution
-/// A primitive to compute deconvolution using different algorithms.
-///
-/// @sa @ref c_api_deconvolution in @ref c_api
-/// @{
-
-struct deconvolution_forward: public primitive {
- struct desc {
- mkldnn_deconvolution_desc_t data;
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, &bias_desc.data,
- &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a deconvolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, nullptr,
- &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a deconvolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, &bias_desc.data,
- &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
- &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated deconvolution forward descriptor");
- }
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, &weights_desc.data, nullptr,
- &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
- &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated deconvolution forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(bias, weights, 1);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- deconvolution_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct deconvolution_backward_data : public primitive {
- struct desc {
- mkldnn_deconvolution_desc_t data;
- desc(algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
- &data, convert_to_c(aalgorithm), &diff_src_desc.data,
- &weights_desc.data, &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a deconvolution backward data descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
- &data, convert_to_c(aalgorithm), &diff_src_desc.data,
- &weights_desc.data, &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated deconvolution backward data descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const deconvolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const deconvolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct deconvolution_backward_weights : public primitive {
- struct desc {
- mkldnn_deconvolution_desc_t data;
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, &diff_bias_desc.data,
- &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a deconvolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
- &strides[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a deconvolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, &diff_bias_desc.data,
- &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated deconvolution backward weights descriptor");
- }
- desc(algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims strides,
- const memory::dims dilates,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(dilates);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
- &data, convert_to_c(aalgorithm), &src_desc.data,
- &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
- &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not create a dilated deconvolution backward weights descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const deconvolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const deconvolution_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(diff_weights, diff_weights, 0);
- REG_QUERY_MD(diff_bias, diff_weights, 1);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_lrn LRN
-/// A primitive to perform local response normalization (LRN) across or within
-/// channels.
-///
-/// @sa @ref c_api_lrn in @ref c_api
-/// @{
-
-struct lrn_forward : public primitive {
- struct desc {
- mkldnn_lrn_desc_t data;
-
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc, memory::dim local_size,
- float alpha, float beta, float k = 1.f) {
- error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
- &src_desc.data, local_size, alpha, beta, k),
- "could not create a lrn forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- lrn_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct lrn_backward : public primitive {
- struct desc {
- mkldnn_lrn_desc_t data;
-
- desc(algorithm aalgorithm, const memory::desc &data_desc,
- const memory::desc &diff_data_desc, memory::dim local_size,
- float alpha, float beta, float k = 1.f) {
- error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
- convert_to_c(aalgorithm), &diff_data_desc.data,
- &data_desc.data, local_size, alpha, beta, k),
- "could not create a lrn backward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const lrn_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const lrn_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- lrn_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_pooling Pooling
-/// A primitive to perform max or average pooling.
-///
-/// @sa @ref c_api_pooling in @ref c_api
-/// @{
-
-struct pooling_forward : public primitive {
- struct desc {
- mkldnn_pooling_desc_t data;
- desc(prop_kind aprop_kind, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &dst_desc,
- const memory::dims strides,
- const memory::dims kernel,
- const memory::dims padding_l,
- const memory::dims padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(kernel);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind),
- convert_to_c(aalgorithm),
- &src_desc.data, &dst_desc.data,
- &strides[0], &kernel[0],
- &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not init a forward pooling descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- pooling_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct pooling_backward : public primitive {
- struct desc {
- mkldnn_pooling_desc_t data;
- desc(algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const memory::dims &strides,
- const memory::dims &kernel,
- const memory::dims &padding_l,
- const memory::dims &padding_r,
- const padding_kind apadding_kind) {
- memory::validate_dims(strides);
- memory::validate_dims(kernel);
- memory::validate_dims(padding_l);
- memory::validate_dims(padding_r);
- error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
- convert_to_c(aalgorithm),
- &diff_src_desc.data, &diff_dst_desc.data,
- &strides[0], &kernel[0],
- &padding_l[0], &padding_r[0],
- mkldnn::convert_to_c(apadding_kind)),
- "could not init a backward pooling descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const pooling_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const pooling_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- pooling_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_eltwise Eltwise
-/// A primitive to compute element-wise operations such as the parametric
-/// rectified linear unit (ReLU).
-///
-/// @sa @ref c_api_eltwise in @ref c_api
-/// @{
-
-struct eltwise_forward : public primitive {
- struct desc {
- mkldnn_eltwise_desc_t data;
- template <typename T>
- desc(prop_kind aprop_kind, algorithm alg_kind,
- const memory::desc &src_desc, T alpha = 0, T beta = 0) {
- error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind),
- mkldnn::convert_to_c(alg_kind), &src_desc.data,
- static_cast<float>(alpha), static_cast<float>(beta)),
- "could not create a eltwise forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- eltwise_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct eltwise_backward : public primitive {
- struct desc {
- mkldnn_eltwise_desc_t data;
-
- template <typename T>
- desc(algorithm alg_kind, const memory::desc &diff_data_desc,
- const memory::desc &data_desc, T alpha = 0, T beta = 0) {
- error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
- mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
- &data_desc.data, static_cast<float>(alpha),
- static_cast<float>(beta)),
- "could not create a eltwise backward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const eltwise_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const eltwise_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- eltwise_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_softmax Softmax
-/// A primitive to perform softmax.
-///
-/// @sa @ref c_api_softmax in @ref c_api
-/// @{
-
-struct softmax_forward : public primitive {
- struct desc {
- mkldnn_softmax_desc_t data;
- desc(prop_kind aprop_kind, const memory::desc &data_desc,
- int softmax_axis) {
- error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), &data_desc.data,
- softmax_axis),
- "could not create a softmax forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- softmax_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct softmax_backward : public primitive {
- struct desc {
- mkldnn_softmax_desc_t data;
- desc(const memory::desc &diff_desc, const memory::desc &data_desc,
- int softmax_axis) {
- error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
- &diff_desc.data, &data_desc.data, softmax_axis),
- "could not init a backward softmax descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const softmax_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const softmax_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- softmax_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_batch_norm Batch normalization
-/// A primitive to perform batch normalization.
-///
-/// @sa @ref c_api_batch_normalization in @ref c_api
-/// @{
-
-struct batch_normalization_forward : public primitive {
- struct desc {
- mkldnn_batch_normalization_desc_t data;
- template <typename T>
- desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
- unsigned flags) {
- error::wrap_c_api(
- mkldnn_batch_normalization_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), &src_desc.data,
- static_cast<float>(epsilon), flags),
- "could not create a batch normalization forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
-
- memory::desc mean_desc() const { return stat_desc(mean); }
- memory::desc variance_desc() const { return stat_desc(var); }
-
- private:
- enum { mean = 1, var = 2, };
- memory::desc stat_desc(int kind) const {
- mkldnn_batch_normalization_desc_t *p;
- error::wrap_c_api(mkldnn_primitive_desc_query(
- get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
- "could not get a batch-normalization descriptor");
- return query_md(p->flags & use_global_stats ? src_md : dst_md, kind);
- }
- };
-
- batch_normalization_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct batch_normalization_backward : public primitive {
- struct desc {
- mkldnn_batch_normalization_desc_t data;
- template <typename T>
- desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
- const memory::desc &data_desc, T epsilon, unsigned flags) {
- error::wrap_c_api(
- mkldnn_batch_normalization_backward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind),
- &diff_data_desc.data, &data_desc.data,
- static_cast<float>(epsilon), flags),
- "could not create a batch normalization backward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const batch_normalization_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const batch_normalization_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(mean, src, 1);
- REG_QUERY_MD(variance, src, 2);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(workspace, workspace, 0);
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_weights, diff_weights, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- batch_normalization_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_inner_product Inner Product
-/// A primitive to compute an inner product.
-///
-/// @sa @ref c_api_inner_product in @ref c_api
-/// @{
-
-struct inner_product_forward: public primitive {
- struct desc {
- mkldnn_inner_product_desc_t data;
- desc(prop_kind aprop_kind, const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_desc) {
- error::wrap_c_api(
- mkldnn_inner_product_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), &src_desc.data,
- &weights_desc.data, &bias_desc.data, &dst_desc.data),
- "could not create a inner product forward descriptor");
- }
-
- desc(prop_kind aprop_kind, const memory::desc &src_desc,
- const memory::desc &weights_desc,
- const memory::desc &dst_desc) {
- error::wrap_c_api(
- mkldnn_inner_product_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), &src_desc.data,
- &weights_desc.data, nullptr, &dst_desc.data),
- "could not create a inner product forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(bias, weights, 1);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- inner_product_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct inner_product_backward_data: public primitive {
- struct desc {
- mkldnn_inner_product_desc_t data;
- desc(const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc) {
- error::wrap_c_api(
- mkldnn_inner_product_backward_data_desc_init(&data,
- &diff_src_desc.data, &weights_desc.data,
- &diff_dst_desc.data),
- "could not create a inner product backward data descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const inner_product_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const inner_product_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(weights, weights, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- inner_product_backward_data(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct inner_product_backward_weights: public primitive {
- struct desc {
- mkldnn_inner_product_desc_t data;
- desc(const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc) {
- error::wrap_c_api(
- mkldnn_inner_product_backward_weights_desc_init(
- &data, &src_desc.data, &diff_weights_desc.data,
- &diff_bias_desc.data, &diff_dst_desc.data),
- "could not create a inner product backward weights descriptor");
- }
- desc(const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc) {
- error::wrap_c_api(
- mkldnn_inner_product_backward_weights_desc_init(
- &data, &src_desc.data, &diff_weights_desc.data,
- nullptr, &diff_dst_desc.data),
- "could not create a inner product backward weights descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const inner_product_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const inner_product_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(diff_weights, diff_weights, 0);
- REG_QUERY_MD(diff_bias, diff_weights, 1);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_rnn RNN
-/// A primitive to compute a common recurrent layer.
-///
-/// @sa @ref c_api_rnn in @ref c_api
-/// @{
-
-struct rnn_cell {
- struct desc {
- mkldnn_rnn_cell_desc_t c_rnn_cell_;
-
- desc(algorithm kind, algorithm activation_f) {
- error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
- mkldnn::convert_to_c(kind),
- mkldnn::convert_to_c(activation_f), 0U, 0, 0),
- "could not init an rnn cell descriptor");
- }
- desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
-
- operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
-
- algorithm get_cell_kind() const
- { return algorithm(c_rnn_cell_.cell_kind); }
- algorithm get_activation() const
- { return algorithm(c_rnn_cell_.activation_kind); }
-
- float get_alpha() const { return c_rnn_cell_.alpha; }
- void set_alpha(float alpha) {
- c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
- c_rnn_cell_.alpha = alpha;
- }
-
- float get_clipping() const { return c_rnn_cell_.clipping; }
- void set_clipping(float clipping) {
- c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
- c_rnn_cell_.clipping = clipping;
- }
-
- int get_gates_count() const {
- return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
- }
- int get_state_count() const {
- return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
- }
- };
-};
-
-struct rnn_forward : public primitive {
- struct desc {
- mkldnn_rnn_desc_t data;
- desc(prop_kind aprop_kind, rnn_cell::desc cell,
- const rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc
- ) {
- error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), cell,
- mkldnn::convert_to_c(direction),
- &src_layer_desc.data, &src_iter_desc.data,
- &weights_layer_desc.data, &weights_iter_desc.data,
- &bias_desc.data,
- &dst_layer_desc.data, &dst_iter_desc.data),
- "could not create an RNN forward descriptor");
- }
-
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
- : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
-
- REG_QUERY_MD(src_layer, src, 0);
- REG_QUERY_MD(src_iter, src, 1);
- REG_QUERY_MD(weights_layer, weights, 0);
- REG_QUERY_MD(weights_iter, weights, 1);
- REG_QUERY_MD(bias, weights, 2);
- REG_QUERY_MD(dst_layer, dst, 0);
- REG_QUERY_MD(dst_iter, dst, 1);
- REG_QUERY_MD(workspace, workspace, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- rnn_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct rnn_backward : public primitive {
- struct desc {
- mkldnn_rnn_desc_t data;
- desc(prop_kind aprop_kind, rnn_cell::desc cell,
- const rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc) {
- error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), cell,
- mkldnn::convert_to_c(direction),
- &src_layer_desc.data, &src_iter_desc.data,
- &weights_layer_desc.data, &weights_iter_desc.data,
- &bias_desc.data,
- &dst_layer_desc.data, &dst_iter_desc.data,
- &diff_src_layer_desc.data, &diff_src_iter_desc.data,
- &diff_weights_layer_desc.data,
- &diff_weights_iter_desc.data, &diff_bias_desc.data,
- &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
- "could not create an RNN backward descriptor");
- }
-
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const rnn_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
- const rnn_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(src_layer, src, 0);
- REG_QUERY_MD(src_iter, src, 1);
- REG_QUERY_MD(weights_layer, weights, 0);
- REG_QUERY_MD(weights_iter, weights, 1);
- REG_QUERY_MD(bias, weights, 2);
- REG_QUERY_MD(dst_layer, dst, 0);
- REG_QUERY_MD(dst_iter, dst, 1);
- REG_QUERY_MD(workspace, workspace, 0);
-
- REG_QUERY_MD(diff_src_layer, diff_src, 0);
- REG_QUERY_MD(diff_src_iter, diff_src, 1);
- REG_QUERY_MD(diff_weights_layer, diff_weights, 0);
- REG_QUERY_MD(diff_weights_iter, diff_weights, 1);
- REG_QUERY_MD(diff_bias, diff_weights, 2);
- REG_QUERY_MD(diff_dst_layer, diff_dst, 0);
- REG_QUERY_MD(diff_dst_iter, diff_dst, 1);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- // With last iteration (with and without input src_iter)
- rnn_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @addtogroup cpp_api_shuffle Shuffle
-/// A primitive to shuffle data along the axis.
-///
-/// @sa @ref c_api_shuffle in @ref c_api
-/// @{
-
-struct shuffle_forward : public primitive {
- struct desc {
- mkldnn_shuffle_desc_t data;
- desc(prop_kind aprop_kind, const memory::desc &data_desc,
- int axis, int group_size) {
- error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
- mkldnn::convert_to_c(aprop_kind), &data_desc.data,
- axis, group_size),
- "could not create a shuffle forward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
-
- REG_QUERY_MD(src, src, 0);
- REG_QUERY_MD(dst, dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- shuffle_forward(const primitive_desc &pd): primitive(pd) {}
-};
-
-struct shuffle_backward : public primitive {
- struct desc {
- mkldnn_shuffle_desc_t data;
- desc(const memory::desc &diff_data_desc, int axis, int group_size) {
- error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
- &diff_data_desc.data, axis, group_size),
- "could not create a shuffle backward descriptor");
- }
- };
-
- struct primitive_desc : public mkldnn::primitive_desc {
- primitive_desc(const desc &desc, const engine &e,
- const shuffle_forward::primitive_desc &hint_fwd_pd)
- : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
-
- REG_QUERY_MD(diff_src, diff_src, 0);
- REG_QUERY_MD(diff_dst, diff_dst, 0);
- REG_QUERY_MD(scratchpad, scratchpad, 0);
- };
-
- shuffle_backward(const primitive_desc &pd): primitive(pd) {}
-};
-
-/// @}
-
-/// @} Primitives
-
-/// @} C++ API
-
-#undef REG_QUERY_MD
-
-// implementation section
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-
-inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) {
- mkldnn_primitive_t result;
- error::wrap_c_api(mkldnn_primitive_create(&result, c_pd),
- "could not create a primitive");
- reset(result);
-}
-
-inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {}
-
-inline void primitive::execute(stream &astream,
- const std::unordered_map<int, memory> &args) const {
- std::vector<mkldnn_exec_arg_t> c_args;
- c_args.reserve(args.size());
- for (const auto &a: args)
- c_args.push_back({a.first, a.second.get()});
-
- error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(),
- (int)c_args.size(), c_args.data()),
- "primitive execution fail");
-}
-#endif // DOXYGEN_SHOULD_SKIP_THIS
-
-} // namespace mkldnn
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
deleted file mode 100644
index f4dc2fdfa6..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h
+++ /dev/null
@@ -1,98 +0,0 @@
-/*******************************************************************************
-* Copyright 2018-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/* DO NOT EDIT, AUTO-GENERATED */
-
-#ifndef MKLDNN_DEBUG_H
-#define MKLDNN_DEBUG_H
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-
-/* All symbols shall be internal unless marked as MKLDNN_API */
-#if defined _WIN32 || defined __CYGWIN__
-# define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
-# define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
-#else
-# if __GNUC__ >= 4
-# define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
-# define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
-# else
-# define MKLDNN_HELPER_DLL_IMPORT
-# define MKLDNN_HELPER_DLL_EXPORT
-# endif
-#endif
-
-#ifdef MKLDNN_DLL
-# ifdef MKLDNN_DLL_EXPORTS
-# define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
-# else
-# define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
-# endif
-#else
-# define MKLDNN_API
-#endif
-
-#if defined (__GNUC__)
-# define MKLDNN_DEPRECATED __attribute__((deprecated))
-#elif defined(_MSC_VER)
-# define MKLDNN_DEPRECATED __declspec(deprecated)
-#else
-# define MKLDNN_DEPRECATED
-#endif
-
-#include "mkldnn_types.h"
-#endif /* DOXYGEN_SHOULD_SKIP_THIS */
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v);
-const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v);
-const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v);
-const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v);
-const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v);
-const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v);
-const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v);
-const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v);
-
-/** Forms a format string for a given memory descriptor.
- *
- * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'.
- * Here:
- * - dt -- data type
- * - p -- indicates there is non-trivial padding
- * - o -- indicates there is non-trivial padding offset
- * - 0 -- indicates there is non-trivial offset0
- * - fmt_kind -- format kind (blocked, wino, etc...)
- * - fmt -- extended format string (format_kind specific)
- * - extra -- shows extra fields (underspecified)
- */
-int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len,
- const mkldnn_memory_desc_t *md);
-
-/** Forms a dimension string for a given memory descriptor.
- *
- * The format is defined as: 'dim0xdim1x...xdimN'.
- */
-int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len,
- const mkldnn_memory_desc_t *md);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
deleted file mode 100644
index 1b6c356982..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn_types.h
+++ /dev/null
@@ -1,1415 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_TYPES_H
-#define MKLDNN_TYPES_H
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#ifndef DOXYGEN_SHOULD_SKIP_THIS
-#include <stddef.h>
-#include <stdint.h>
-#endif
-
-/** @addtogroup c_api C API
- * @{
- *
- * @addtogroup c_api_types Types
- * @{
- *
- * @addtogroup c_api_types_generic Generic
- * @{ */
-
-/** Intel(R) MKL-DNN Version type */
-typedef struct {
- int major;
- int minor;
- int patch;
- const char *hash;
-} mkldnn_version_t;
-
-/** Status values returned by Intel(R) MKL-DNN functions. */
-typedef enum {
- /** The operation was successful */
- mkldnn_success = 0,
- /** The operation failed due to an out-of-memory condition */
- mkldnn_out_of_memory = 1,
- /** The operation failed and should be retried */
- mkldnn_try_again = 2,
- /** The operation failed because of incorrect function arguments */
- mkldnn_invalid_arguments = 3,
- /** The operation failed because a primitive was not ready for execution */
- mkldnn_not_ready = 4,
-    /** The operation failed because the requested functionality is not implemented
-     */
- mkldnn_unimplemented = 5,
- /** Primitive iterator passed over last primitive descriptor */
- mkldnn_iterator_ends = 6,
- /** Primitive or engine failed on execution */
- mkldnn_runtime_error = 7,
- /** Queried element is not required for given primitive */
- mkldnn_not_required = 8,
-} mkldnn_status_t;
-
-/** Data type specification */
-typedef enum {
- /** Undefined data type, used for empty memory descriptors. */
- mkldnn_data_type_undef = 0,
- /** 32-bit/single-precision floating point. */
- mkldnn_f32 = 1,
- /** 32-bit signed integer. */
- mkldnn_s32 = 2,
- /** 8-bit signed integer. */
- mkldnn_s8 = 3,
- /** 8-bit unsigned integer. */
- mkldnn_u8 = 4,
-} mkldnn_data_type_t;
-
-/** Memory format kind */
-typedef enum {
- /** Undefined memory format, used for empty memory descriptors. */
- mkldnn_format_kind_undef = 0,
- /** Unspecified format. The primitive selects a format automatically. */
- mkldnn_format_kind_any,
- /** A tensor in a generic format described by the stride and blocking
- * values in each dimension. See #mkldnn_blocking_desc_t for more
- * information. */
- mkldnn_blocked,
- /** Weights format used in 8bit Winograd convolution */
- mkldnn_format_kind_wino,
- /** Packed weights format used in RNN */
- mkldnn_format_kind_rnn_packed,
-} mkldnn_format_kind_t;
-
-/** Memory format tag specification.
- *
- * Intel MKL-DNN formats describe physical data layout. The physical layout
- * is described as a sequence of the dimensions as they are laid out in the
- * memory (from the outer-most to the inner-most). Note that this order
- * doesn't affect the logical order of the dimensions that is kept in the
- * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the
- * dimensions is specified by the type of tensor.
- *
- * For example, CNN 5D tensor always has its logical dimensions in the order
- * `(batch, channels, depth, height, width)`, while the physical layout might be
- * #mkldnn_ncdhw or #mkldnn_ndhwc:
- *
- * ~~~cpp
- * int batch = 2, channels = 16, depth = 13, height = 13, width = 13;
- *
- * int ndims = 5; // 5D tensor
- * mkldnn_dims_t dims = {batch, channels, depth, height, width};
- * mkldnn_memory_desc_t data_in_ncdhw;
- * mkldnn_memory_desc_init_by_tag(
- * &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw);
- *
- * // note that in both cases dims passed are the same
- * mkldnn_memory_desc_t data_in_ndhwc;
- * mkldnn_memory_desc_init_by_tag(
- * &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc);
- * ~~~
- *
- * The following notation applies to memory format names:
- * - @c 'n' denotes the mini-batch dimension
- * - @c 'c' denotes a channels dimension
- * - When there are multiple channel dimensions (for example, in convolution
- * weights tensor), @c 'i' and @c 'o' denote dimensions of input and output
- * channels
- * - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
- * respectively
- * - Upper-case letters indicate that the data is laid out in blocks
- * for a particular dimension. In such cases, the format name contains both
- * upper- and lower-case letters for that dimension with a lower-case letter
- * preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a
- * format where the outermost dimension is mini-batch, followed by the
- * channel block number, followed by the spatial height and width, and
- * finally followed by 8-element channel blocks.
- *
- * @note
- * Channel designations can be different. For example, both the @c
- * 'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D
- * tensor.
- *
- * @sa @ref understanding_memory_formats
- */
-typedef enum {
- /** Undefined memory format tag */
- mkldnn_format_tag_undef = 0,
- /** Undefined memory format tag.
- * The primitive selects a format automatically. */
- mkldnn_format_tag_any,
-
- /* Semantic agnostic section */
- /* The physical order of dimensions is defined by the permutation of the
- * characters, assuming that ab..z defines the natural order.
- */
-
- /* Plain formats */
-
- mkldnn_a,
- mkldnn_ab,
- mkldnn_abc,
- mkldnn_abcd,
- mkldnn_abcde,
- mkldnn_abcdef,
- mkldnn_abdec,
- mkldnn_acb,
- mkldnn_acbde,
- mkldnn_acdb,
- mkldnn_acdeb,
- mkldnn_ba,
- mkldnn_bac,
- mkldnn_bacd,
- mkldnn_bcda,
- mkldnn_cba,
- mkldnn_cdba,
- mkldnn_cdeba,
- mkldnn_decab,
-
- /* Opaque blocked formats */
-
- mkldnn_Abc16a,
- mkldnn_ABc16a16b,
- mkldnn_aBc16b,
- mkldnn_ABc16b16a,
- mkldnn_Abc4a,
- mkldnn_aBc4b,
- mkldnn_ABc4b16a4b,
- mkldnn_ABc4b4a,
- mkldnn_ABc8a16b2a,
- mkldnn_ABc8a8b,
- mkldnn_aBc8b,
- mkldnn_ABc8b16a2b,
- mkldnn_ABc8b8a,
- mkldnn_Abcd16a,
- mkldnn_ABcd16a16b,
- mkldnn_aBcd16b,
- mkldnn_ABcd16b16a,
- mkldnn_aBCd16b16c,
- mkldnn_aBCd16c16b,
- mkldnn_Abcd4a,
- mkldnn_aBcd4b,
- mkldnn_ABcd4b16a4b,
- mkldnn_ABcd4b4a,
- mkldnn_aBCd4c16b4c,
- mkldnn_aBCd4c4b,
- mkldnn_ABcd8a16b2a,
- mkldnn_ABcd8a8b,
- mkldnn_aBcd8b,
- mkldnn_ABcd8b16a2b,
- mkldnn_aBCd8b16c2b,
- mkldnn_ABcd8b8a,
- mkldnn_aBCd8b8c,
- mkldnn_aBCd8c16b2c,
- mkldnn_aBCd8c8b,
- mkldnn_Abcde16a,
- mkldnn_ABcde16a16b,
- mkldnn_aBcde16b,
- mkldnn_ABcde16b16a,
- mkldnn_aBCde16b16c,
- mkldnn_aBCde16c16b,
- mkldnn_aBCde2c8b4c,
- mkldnn_Abcde4a,
- mkldnn_aBcde4b,
- mkldnn_ABcde4b4a,
- mkldnn_aBCde4b4c,
- mkldnn_aBCde4c16b4c,
- mkldnn_aBCde4c4b,
- mkldnn_Abcde8a,
- mkldnn_ABcde8a8b,
- mkldnn_aBcde8b,
- mkldnn_ABcde8b16a2b,
- mkldnn_aBCde8b16c2b,
- mkldnn_ABcde8b8a,
- mkldnn_aBCde8b8c,
- mkldnn_aBCde8c16b2c,
- mkldnn_aBCde8c8b,
- mkldnn_aBcdef16b,
- mkldnn_aBCdef16b16c,
- mkldnn_aBCdef16c16b,
- mkldnn_aBcdef4b,
- mkldnn_aBCdef4c4b,
- mkldnn_aBCdef8b8c,
- mkldnn_aBCdef8c16b2c,
- mkldnn_aBCdef8c8b,
- mkldnn_aBdc16b,
- mkldnn_aBdc4b,
- mkldnn_aBdc8b,
- mkldnn_aBdec16b,
- mkldnn_aBdec4b,
- mkldnn_aBdec8b,
- mkldnn_aBdefc16b,
- mkldnn_aBdefc4b,
- mkldnn_aBdefc8b,
- mkldnn_Acb16a,
- mkldnn_Acb4a,
- mkldnn_Acb8a,
- mkldnn_aCBd16b16c,
- mkldnn_aCBde16b16c,
- mkldnn_Acdb16a,
- mkldnn_Acdb4a,
- mkldnn_Acdb8a,
- mkldnn_Acdeb16a,
- mkldnn_Acdeb4a,
- mkldnn_Acdeb8a,
- mkldnn_BAc16a16b,
- mkldnn_BAcd16a16b,
-
-    /** Just a sentinel, not a real memory format tag. Must be updated after a
-     * new format tag is added. */
- mkldnn_format_tag_last,
-
- /* Aliases */
-
- mkldnn_x = mkldnn_a,
- mkldnn_nc = mkldnn_ab,
- mkldnn_cn = mkldnn_ba,
- mkldnn_ncw = mkldnn_abc,
- mkldnn_nwc = mkldnn_acb,
- mkldnn_nchw = mkldnn_abcd,
- mkldnn_nhwc = mkldnn_acdb,
- mkldnn_chwn = mkldnn_bcda,
- mkldnn_ncdhw = mkldnn_abcde,
- mkldnn_ndhwc = mkldnn_acdeb,
-
- mkldnn_oi = mkldnn_ab,
- mkldnn_io = mkldnn_ba,
- mkldnn_oiw = mkldnn_abc,
- mkldnn_wio = mkldnn_cba,
- mkldnn_oihw = mkldnn_abcd,
- mkldnn_hwio = mkldnn_cdba,
- mkldnn_ihwo = mkldnn_bcda,
- mkldnn_iohw = mkldnn_bacd,
- mkldnn_oidhw = mkldnn_abcde,
- mkldnn_dhwio = mkldnn_cdeba,
- mkldnn_goiw = mkldnn_abcd,
- mkldnn_goihw = mkldnn_abcde,
- mkldnn_hwigo = mkldnn_decab,
- mkldnn_giohw = mkldnn_acbde,
- mkldnn_goidhw = mkldnn_abcdef,
-
- /** 3D RNN data tensor in the format (seq_length, batch, input channels). */
- mkldnn_tnc = mkldnn_abc,
- /** 3D RNN data tensor in the format (batch, seq_length, input channels). */
- mkldnn_ntc = mkldnn_bac,
- /** 5D RNN states tensor in the format (num_layers, num_directions,
- * num_states, batch, state channels). */
- mkldnn_ldsnc = mkldnn_abcde,
- /** 5D RNN weights tensor in the format (num_layers, num_directions,
- * input_channels, num_gates, output_channels).
- *
- * - For LSTM cells, the gates order is input, forget, candidate
- * and output gate.
- * - For GRU cells, the gates order is update, reset and output gate. */
- mkldnn_ldigo = mkldnn_abcde,
- /** 5D RNN weights tensor in the format (num_layers, num_directions,
- * num_gates, output_channels, input_channels).
- *
- * - For LSTM cells, the gates order is input, forget, candidate
- * and output gate.
- * - For GRU cells, the gates order is update, reset and output gate. */
- mkldnn_ldgoi = mkldnn_abdec,
- /** 4D RNN bias tensor in the format (num_layers, num_directions,
- * num_gates, output_channels).
- *
- * - For LSTM cells, the gates order is input, forget, candidate
- * and output gate.
- * - For GRU cells, the gates order is update, reset and output gate. */
- mkldnn_ldgo = mkldnn_abcd,
-
- /* Opaque data types, are not to be used explicitly */
-
- /* data */
- mkldnn_nCdhw16c = mkldnn_aBcde16b,
- mkldnn_nCdhw4c = mkldnn_aBcde4b,
- mkldnn_nCdhw8c = mkldnn_aBcde8b,
- mkldnn_nChw16c = mkldnn_aBcd16b,
- mkldnn_nChw4c = mkldnn_aBcd4b,
- mkldnn_nChw8c = mkldnn_aBcd8b,
- mkldnn_nCw16c = mkldnn_aBc16b,
- mkldnn_nCw4c = mkldnn_aBc4b,
- mkldnn_nCw8c = mkldnn_aBc8b,
-
- /* weights, 3D */
- mkldnn_IOw16o16i = mkldnn_BAc16a16b,
- mkldnn_OIw16i16o = mkldnn_ABc16b16a,
- mkldnn_OIw16o16i = mkldnn_ABc16a16b,
- mkldnn_Oiw16o = mkldnn_Abc16a,
- mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b,
- mkldnn_OIw4i4o = mkldnn_ABc4b4a,
- mkldnn_Oiw4o = mkldnn_Abc4a,
- mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b,
- mkldnn_OIw8i8o = mkldnn_ABc8b8a,
- mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a,
- mkldnn_OIw8o8i = mkldnn_ABc8a8b,
- mkldnn_Owi16o = mkldnn_Acb16a,
- mkldnn_Owi4o = mkldnn_Acb4a,
- mkldnn_Owi8o = mkldnn_Acb8a,
-
- /* weights, 4D */
- mkldnn_IOhw16o16i = mkldnn_BAcd16a16b,
- mkldnn_Ohwi16o = mkldnn_Acdb16a,
- mkldnn_Ohwi4o = mkldnn_Acdb4a,
- mkldnn_Ohwi8o = mkldnn_Acdb8a,
- mkldnn_OIhw16i16o = mkldnn_ABcd16b16a,
- mkldnn_OIhw16o16i = mkldnn_ABcd16a16b,
- mkldnn_Oihw16o = mkldnn_Abcd16a,
- mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b,
- mkldnn_OIhw4i4o = mkldnn_ABcd4b4a,
- mkldnn_Oihw4o = mkldnn_Abcd4a,
- mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b,
- mkldnn_OIhw8i8o = mkldnn_ABcd8b8a,
- mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a,
- mkldnn_OIhw8o8i = mkldnn_ABcd8a8b,
-
- /* weights, 5D */
- mkldnn_Odhwi16o = mkldnn_Acdeb16a,
- mkldnn_Odhwi4o = mkldnn_Acdeb4a,
- mkldnn_Odhwi8o = mkldnn_Acdeb8a,
- mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a,
- mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b,
- mkldnn_Oidhw16o = mkldnn_Abcde16a,
- mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a,
- mkldnn_Oidhw4o = mkldnn_Abcde4a,
- mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b,
- mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a,
- mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b,
-
- /* weights w/ groups, 3D */
- mkldnn_Goiw16g = mkldnn_Abcd16a,
- mkldnn_gIOw16o16i = mkldnn_aCBd16b16c,
- mkldnn_gOIw16i16o = mkldnn_aBCd16c16b,
- mkldnn_gOIw16o16i = mkldnn_aBCd16b16c,
- mkldnn_gOiw16o = mkldnn_aBcd16b,
- mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c,
- mkldnn_gOIw4i4o = mkldnn_aBCd4c4b,
- mkldnn_gOiw4o = mkldnn_aBcd4b,
- mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c,
- mkldnn_gOIw8i8o = mkldnn_aBCd8c8b,
- mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b,
- mkldnn_gOIw8o8i = mkldnn_aBCd8b8c,
- mkldnn_gOwi16o = mkldnn_aBdc16b,
- mkldnn_gOwi4o = mkldnn_aBdc4b,
- mkldnn_gOwi8o = mkldnn_aBdc8b,
-
- /* weights w/ groups, 4D */
- mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c,
- mkldnn_gOhwi16o = mkldnn_aBdec16b,
- mkldnn_gOhwi4o = mkldnn_aBdec4b,
- mkldnn_gOhwi8o = mkldnn_aBdec8b,
- mkldnn_Goihw16g = mkldnn_Abcde16a,
- mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b,
- mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c,
- mkldnn_gOihw16o = mkldnn_aBcde16b,
- mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c,
- mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c,
- mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b,
- mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c,
- mkldnn_gOihw4o = mkldnn_aBcde4b,
- mkldnn_Goihw8g = mkldnn_Abcde8a,
- mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c,
- mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b,
- mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b,
- mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c,
-
- /* weights w/ groups, 6D */
- mkldnn_gOdhwi16o = mkldnn_aBdefc16b,
- mkldnn_gOdhwi4o = mkldnn_aBdefc4b,
- mkldnn_gOdhwi8o = mkldnn_aBdefc8b,
- mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b,
- mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c,
- mkldnn_gOidhw16o = mkldnn_aBcdef16b,
- mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b,
- mkldnn_gOidhw4o = mkldnn_aBcdef4b,
- mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c,
- mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b,
- mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c,
-} mkldnn_format_tag_t;
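-
-/* Illustrative sketch (not part of the original header): how a logical
- * element (n, c, h, w) maps to a physical offset in the mkldnn_nChw8c layout
- * described above, assuming a dense tensor whose channel count is divisible
- * by the block size of 8. */
-static inline size_t nChw8c_offset(size_t n, size_t c, size_t h, size_t w,
-        size_t C, size_t H, size_t W) {
-    const size_t blk = 8;      /* 8-element channel blocks */
-    const size_t Cb = C / blk; /* number of channel blocks */
-    return (((n * Cb + c / blk) * H + h) * W + w) * blk + c % blk;
-}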
-
-/** Kinds of padding. Define how to interpret the data in padding regions. */
-typedef enum {
- /** The data in padding regions is zero. */
- mkldnn_padding_zero,
-} mkldnn_padding_kind_t;
-
-/** Kinds of propagation. */
-typedef enum {
- /* TODO: suggest renames */
- /** Undefined propagation type. */
- mkldnn_prop_kind_undef = 0,
- /** Forward data propagation (training mode). In this mode primitives
- * perform computations necessary for subsequent backward propagation. */
- mkldnn_forward_training = 64,
- /** Forward data propagation (inference mode). In this mode primitives
- * perform only computations that are necessary for inference and omit
- * computations that are necessary only for backward propagation. */
- mkldnn_forward_inference = 96,
- /** Forward data propagation (alias for @c mkldnn_forward_inference) */
- mkldnn_forward_scoring = mkldnn_forward_inference,
- /** Forward data propagation (alias for @c mkldnn_forward_training) */
- mkldnn_forward = mkldnn_forward_training,
- /** Backward propagation (with respect to all parameters). */
- mkldnn_backward = 128,
- /** Backward data propagation */
- mkldnn_backward_data = 160,
- /** Backward weights propagation */
- mkldnn_backward_weights = 192,
- /** Backward bias propagation */
- mkldnn_backward_bias = 193,
-} mkldnn_prop_kind_t;
-
-/** Kinds of primitives. Used to implement a way to extend the library with new
- * primitives without changing the ABI. */
-typedef enum {
- /** Undefined primitive (XXX: why do we have it?). */
- mkldnn_undefined_primitive,
- /** A reorder primitive.*/
- mkldnn_reorder,
- /** A shuffle primitive.*/
- mkldnn_shuffle,
- /** A (out-of-place) concat primitive. */
- mkldnn_concat,
- /** A sum primitive. */
- mkldnn_sum,
- /** A convolution primitive. */
- mkldnn_convolution,
- /** A deconvolution primitive. */
- mkldnn_deconvolution,
- /** An element-wise primitive. */
- mkldnn_eltwise,
- /** A Softmax primitive. */
- mkldnn_softmax,
- /** A pooling primitive. */
- mkldnn_pooling,
- /** An LRN primitive. */
- mkldnn_lrn,
- /** A batch normalization primitive. */
- mkldnn_batch_normalization,
- /** An inner product primitive. */
- mkldnn_inner_product,
- /** An RNN primitive. */
- mkldnn_rnn,
-} mkldnn_primitive_kind_t;
-
-/** Kinds of algorithms. */
-typedef enum {
- mkldnn_alg_kind_undef,
- /** Direct convolution */
- mkldnn_convolution_direct = 0x1,
- /** Winograd convolution */
- mkldnn_convolution_winograd = 0x2,
- /** Convolution algorithm (either direct or Winograd) is chosen just in time. */
- mkldnn_convolution_auto = 0x3,
- /** Direct deconvolution */
- mkldnn_deconvolution_direct = 0xa,
- /** Winograd deconvolution */
- mkldnn_deconvolution_winograd = 0xb,
- /** Eltwise: ReLU */
- mkldnn_eltwise_relu = 0x1f,
- /** Eltwise: hyperbolic tangent non-linearity (tanh) */
- mkldnn_eltwise_tanh = 0x2f,
- /** Eltwise: parametric exponential linear unit (elu) */
- mkldnn_eltwise_elu = 0x3f,
- /** Eltwise: square */
- mkldnn_eltwise_square = 0x4f,
- /** Eltwise: abs */
- mkldnn_eltwise_abs = 0x5f,
- /** Eltwise: square root */
- mkldnn_eltwise_sqrt = 0x6f,
- /** Eltwise: linear */
- mkldnn_eltwise_linear = 0x7f,
- /** Eltwise: bounded_relu */
- mkldnn_eltwise_bounded_relu = 0x8f,
- /** Eltwise: soft_relu */
- mkldnn_eltwise_soft_relu = 0x9f,
- /** Eltwise: logistic */
- mkldnn_eltwise_logistic = 0xaf,
- /** Max pooling */
- mkldnn_pooling_max = 0x1ff,
- /** Average pooling include padding */
- mkldnn_pooling_avg_include_padding = 0x2ff,
- /** Average pooling exclude padding */
- mkldnn_pooling_avg_exclude_padding = 0x3ff,
- mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
- /** Local response normalization (LRN) across multiple channels */
- mkldnn_lrn_across_channels = 0xaff,
- /** LRN within a single channel */
- mkldnn_lrn_within_channel = 0xbff,
- /** RNN cell */
- mkldnn_vanilla_rnn = 0x1fff,
- /** LSTM cell */
- mkldnn_vanilla_lstm = 0x2fff,
- /** GRU cell */
- mkldnn_vanilla_gru = 0x3fff,
- /** GRU cell with linear before reset
- *
- * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
- * in how the new memory gate is calculated:
- * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
- * Primitive expects 4 biases on input:
- * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
- * */
- mkldnn_gru_linear_before_reset = 0x4fff,
-} mkldnn_alg_kind_t;
-
-/** Flags for batch-normalization primitive. */
-typedef enum {
- /** Use global statistics
- *
- * If specified
- * - on forward propagation use mean and variance provided by user (input)
- * - on backward propagation reduces the amount of computations, since
- * mean and variance are considered as constants
- *
- * If not specified:
- * - on forward propagation mean and variance are computed and stored in
- * output
- * - on backward propagation compute the full derivative with respect to data
- */
- mkldnn_use_global_stats = 0x1U,
- /** Use scale and shift parameters
- *
- * If specified:
- * - on forward propagation use scale and shift (aka scale and bias) for
- * the batch normalization results
- * - on backward propagation (for prop_kind == #mkldnn_backward) compute
- * diff with respect to scale and shift (hence one extra output is used)
- *
- * If not specified:
- * - on backward propagation prop_kind == #mkldnn_backward_data has the
- * same behavior as prop_kind == #mkldnn_backward
- */
- mkldnn_use_scaleshift = 0x2U,
- /** Fuse with ReLU
- *
- * If specified:
- * - on inference this option behaves the same as if the primitive were
- * fused with ReLU via post ops API
- * - on training primitive requires workspace (required to be able to
- * perform backward pass)
- */
- mkldnn_fuse_bn_relu = 0x4U,
-} mkldnn_batch_normalization_flag_t;
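-
-/* Illustrative sketch: a typical flag combination for training-time forward
- * batch normalization that uses scale/shift and fuses ReLU; statistics are
- * computed by the primitive because mkldnn_use_global_stats is not set. */
-static const unsigned bn_training_flags
-        = mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;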
-
-/** @} */
-
-/** @addtogroup c_api_types_memory Memory
- * @{ */
-
-/** Maximum number of dimensions a tensor can have. Only restricts the amount
- * of space used for the tensor description. Individual computational
- * primitives may support only tensors of certain dimensions. */
-#define MKLDNN_MAX_NDIMS 12
-
-/** A type to describe tensor dimension. */
-typedef int64_t mkldnn_dim_t;
-
-/** A type to describe tensor dimensions. */
-typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS];
-
-/** A type to describe strides within a tensor. */
-typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS];
-
-/** Generic description of blocked data layout for most memory formats.
- *
- * @sa @ref understanding_memory_formats */
-typedef struct {
- /** The strides between the outermost blocks.
- * In case of plain (non-blocked) formats the strides between dimensions. */
- mkldnn_dims_t strides;
- /* Innermost section
- * ASSUMPTION: the innermost blocks are always dense */
- /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */
- int inner_nblks;
- /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */
- mkldnn_dims_t inner_blks;
- /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of
- * `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim */
- mkldnn_dims_t inner_idxs;
-} mkldnn_blocking_desc_t;
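-
-/* Illustrative sketch: the innermost-block fields for the OIhw4i16o4i example
- * mentioned above (strides are omitted; they depend on the padded
- * dimensions). */
-static const mkldnn_blocking_desc_t OIhw4i16o4i_blocking_example = {
-    .inner_nblks = 3,
-    .inner_blks = {4, 16, 4}, /* 4i, 16o, 4i */
-    .inner_idxs = {1, 0, 1},  /* `i` is dim 1, `o` is dim 0 */
-};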
-
-typedef enum {
- /** Undefined memory format, used for empty memory descriptors. */
- mkldnn_wino_undef = 0,
- /** Tensors of weights for 2x3 winograd convolutions. */
- mkldnn_wino_wei_aaOIoi,
- mkldnn_wino_wei_aaOio,
- mkldnn_wino_wei_aaOBiOo,
- /** Tensor of weights for 4x3 convolution. */
- mkldnn_wino_wei_OBaaIBOIio
-} mkldnn_wino_memory_format_t;
-
-/** Description of tensor of weights for winograd 2x3 convolution. */
-typedef struct {
- mkldnn_wino_memory_format_t wino_format;
- int r;
- int alpha;
- int ic;
- int oc;
- int ic_block;
- int oc_block;
- int ic2_block;
- int oc2_block;
- float adj_scale;
- size_t size;
-} mkldnn_wino_desc_t;
-
-typedef enum {
- mkldnn_packed_format_undef = 0,
- mkldnn_ldigo_p,
- mkldnn_ldgoi_p
-} mkldnn_rnn_packed_memory_format_t;
-
-/* Maximum number of parts of RNN weights tensor that require separate
- * computation. */
-#define MKLDNN_RNN_MAX_N_PARTS 4
-
-/** Description of tensor of packed weights for rnn. */
-typedef struct {
- mkldnn_rnn_packed_memory_format_t format;
- int n_parts;
- int n;
- int parts[MKLDNN_RNN_MAX_N_PARTS];
- size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS];
- size_t offset_compensation;
- size_t size;
-} mkldnn_rnn_packed_desc_t;
-
-typedef enum {
- mkldnn_memory_extra_flag_none = 0x0U,
- /** Indicates that the weights have an additional buffer that depends on the
- * @p compensation_mask.
- *
- * For instance, in the 4D case with a compensation mask equal to (1 << 0),
- * the additional buffer would consist of OC values:
- * O[oc : 0,OC] =
- * -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
- */
- mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
- mkldnn_memory_extra_flag_scale_adjust = 0x2U,
-} mkldnn_memory_extra_flags_t;
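-
-/* Illustrative sketch: computing the compensation buffer described above for
- * plain OIhw int8 weights with compensation_mask == (1 << 0), i.e. one value
- * per output channel. */
-static void compute_conv_s8s8_compensation(const int8_t *weights,
-        int32_t *comp, int OC, int IC, int KH, int KW) {
-    for (int oc = 0; oc < OC; ++oc) {
-        int32_t sum = 0;
-        for (int ic = 0; ic < IC; ++ic)
-            for (int kh = 0; kh < KH; ++kh)
-                for (int kw = 0; kw < KW; ++kw)
-                    sum += weights[((oc * IC + ic) * KH + kh) * KW + kw];
-        comp[oc] = -128 * sum;
-    }
-}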
-
-/** Description of extra information stored in memory */
-typedef struct {
- /** The flags contain arbitrary extra information, such as compensation.
- * @sa mkldnn_memory_extra_flags_t */
- uint64_t flags;
- /** Compensation mask */
- int compensation_mask;
- /** Scale applied to the data */
- float scale_adjust;
- /** For future backwards compatibility */
- char reserved[64];
-} mkldnn_memory_extra_desc_t;
-
-/** Memory descriptor. The description is based on a number of dimensions,
- * dimensions themselves, plus information about elements type and memory
- * format. Additionally, contains format-specific descriptions of the data
- * layout. */
-typedef struct {
- /** Number of dimensions */
- int ndims;
- /** Dimensions in the following order:
- * - CNN data tensors: mini-batch, channel, spatial
- * (<code>{N, C, [[D,] H,] W}</code>)
- * - CNN weight tensors: group (optional), output channel, input channel,
- * spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
- * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
- * or layers, directions, states, mini-batch, channels (<code>{L, D, S, N, C}</code>)
- * - RNN weight tensor: layers, directions, input channel, gates, output channels
- * (<code>{L, D, I, G, O}</code>).
- *
- * @note
- * The order of dimensions does not depend on the memory format, so
- * whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc
- * the dims for a 4D CNN data tensor would be <code>{N, C, H, W}</code>.
- */
- mkldnn_dims_t dims;
- /** Data type of the tensor elements. */
- mkldnn_data_type_t data_type;
-
- /** Size of the data including padding in each dimension. */
- mkldnn_dims_t padded_dims;
- /** Per-dimension offset from the padding to the actual data; the top-level
- * tensor with offsets applied must lie within the padding area. */
- mkldnn_dims_t padded_offsets;
-
- /** Offset from memory origin to the current block, non-zero only in
- * a description of a memory sub-block. */
- mkldnn_dim_t offset0;
-
- /** Memory format kind. */
- mkldnn_format_kind_t format_kind;
- union {
- /** Description of the data layout for memory formats that use
- * blocking. */
- mkldnn_blocking_desc_t blocking;
- /** Tensor of weights for integer 8bit winograd convolution. */
- mkldnn_wino_desc_t wino_desc;
- /** Tensor of packed weights for RNN. */
- mkldnn_rnn_packed_desc_t rnn_packed_desc;
- /* ... other descriptions possible */
- } format_desc;
-
- mkldnn_memory_extra_desc_t extra;
-} mkldnn_memory_desc_t;
-
-/** @struct mkldnn_memory
- * An opaque structure to describe a memory. */
-struct mkldnn_memory;
-
-/** A memory handle. */
-typedef struct mkldnn_memory *mkldnn_memory_t;
-
-/** A constant memory handle. */
-typedef const struct mkldnn_memory *const_mkldnn_memory_t;
-
-#define MKLDNN_NATIVE_HANDLE_NONE (NULL)
-#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1)
-
-/** @} */
-
-/** @addtogroup c_api_types_op_descs Operation descriptors
- * @{*/
-
-/** A pointer to any of the operation descriptors. */
-typedef void *mkldnn_op_desc_t;
-/** A pointer to any of the operation descriptors (constant variant). */
-typedef const void *const_mkldnn_op_desc_t;
-
-/** A descriptor of a convolution operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_convolution. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward_data,
- * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
- mkldnn_prop_kind_t prop_kind;
- /** The kind of the convolution algorithm. Possible values:
- * #mkldnn_convolution_direct. */
- mkldnn_alg_kind_t alg_kind;
- /** Source memory descriptor. */
- mkldnn_memory_desc_t src_desc;
- /** Source gradient memory descriptor. */
- mkldnn_memory_desc_t diff_src_desc;
- /** Weights memory descriptor. */
- mkldnn_memory_desc_t weights_desc;
- /** Weights gradient memory descriptor. */
- mkldnn_memory_desc_t diff_weights_desc;
- /** Bias memory descriptor. */
- mkldnn_memory_desc_t bias_desc;
- /** Bias gradient memory descriptor. */
- mkldnn_memory_desc_t diff_bias_desc;
- /** Destination memory descriptor. */
- mkldnn_memory_desc_t dst_desc;
- /** Destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_dst_desc;
- /** Convolution strides in each spatial dimension. */
- mkldnn_dims_t strides;
- /** Convolution dilates in each spatial dimension. */
- mkldnn_dims_t dilates;
- /** Padding in each spatial dimension. padding[0] is a padding in the
- * beginning (@p padding_l), padding[1] is a padding in the end (@p
- * padding_r). */
- mkldnn_dims_t padding[2];
- /** The kind of padding to use. */
- mkldnn_padding_kind_t padding_kind;
- /** The accumulator data type. Initialized automatically. */
- mkldnn_data_type_t accum_data_type;
-} mkldnn_convolution_desc_t;
-
-/** A descriptor of a deconvolution operation. */
-typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
-
-/** A descriptor of a shuffle operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_shuffle. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, and #mkldnn_backward_data. */
- mkldnn_prop_kind_t prop_kind;
- /** Source and destination memory descriptor,
- * and source and destination gradient memory descriptor. */
- mkldnn_memory_desc_t data_desc;
- /** axis for shuffling. */
- int axis;
- /** number of groups in group convolution */
- mkldnn_dim_t group_size;
-} mkldnn_shuffle_desc_t;
-
-/** A descriptor of an element-wise operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_eltwise. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
- */
- mkldnn_prop_kind_t prop_kind;
- /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
- * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
- * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
- * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
- * #mkldnn_eltwise_logistic. */
- mkldnn_alg_kind_t alg_kind;
- /** Source and destination memory descriptor. */
- mkldnn_memory_desc_t data_desc;
- /** Source and destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_data_desc;
- /** Algorithm specific parameter.
- * Accordance table:
- * - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
- * - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
- * - #mkldnn_eltwise_elu: @p alpha -- negative slope, @p beta ignored
- * - #mkldnn_eltwise_square: @p alpha and @p beta ignored
- * - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
- * - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
- * - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
- * - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
- * - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
- * - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
- */
- float alpha, beta;
-} mkldnn_eltwise_desc_t;
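-
-/* Illustrative sketch: reference forward semantics for a few of the eltwise
- * algorithms, following the alpha/beta conventions in the table above. */
-static inline float eltwise_fwd_ref(mkldnn_alg_kind_t alg, float x,
-        float alpha, float beta) {
-    switch (alg) {
-    case mkldnn_eltwise_relu: return x > 0.f ? x : alpha * x; /* alpha = slope */
-    case mkldnn_eltwise_linear: return alpha * x + beta;      /* scale, shift */
-    case mkldnn_eltwise_bounded_relu:                         /* alpha = bound */
-        return x < 0.f ? 0.f : (x > alpha ? alpha : x);
-    default: return x; /* other algorithms omitted in this sketch */
-    }
-}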
-
-/** A descriptor of a Softmax operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_softmax. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training and
- * #mkldnn_forward_inference. */
- mkldnn_prop_kind_t prop_kind;
- /** Source and destination memory descriptor. */
- mkldnn_memory_desc_t data_desc;
- /** Source and destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_desc;
- /** The axis along which to perform the softmax. */
- int softmax_axis;
-} mkldnn_softmax_desc_t;
-
-/** A descriptor of a pooling operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_pooling. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
- */
- mkldnn_prop_kind_t prop_kind;
- /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and
- * #mkldnn_pooling_avg. */
- mkldnn_alg_kind_t alg_kind;
- /** Source memory descriptor. */
- mkldnn_memory_desc_t src_desc;
- /** Source gradient memory descriptor. */
- mkldnn_memory_desc_t diff_src_desc;
- /** Destination memory descriptor. */
- mkldnn_memory_desc_t dst_desc;
- /** Destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_dst_desc;
- /** Pooling kernel strides for spatial dimensions. */
- mkldnn_dims_t strides;
- /** Pooling kernel spatial dimensions. */
- mkldnn_dims_t kernel;
- /** Padding in each spatial dimension. padding[0] is a padding in the
- * beginning (@p padding_l), padding[1] is a padding in the end (@p
- * padding_r). */
- mkldnn_dims_t padding[2];
- /** The kind of padding to use. */
- mkldnn_padding_kind_t padding_kind;
- /** The accumulator data type. Initialized automatically. */
- mkldnn_data_type_t accum_data_type;
-} mkldnn_pooling_desc_t;
-
-/** A descriptor of a Local Response Normalization (LRN) operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_lrn. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
- */
- mkldnn_prop_kind_t prop_kind;
- /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and
- * #mkldnn_lrn_across_channels. */
- mkldnn_alg_kind_t alg_kind;
- /** Source and destination memory descriptor. */
- mkldnn_memory_desc_t data_desc;
- /** Source and destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_data_desc;
- /** The number of channels to sum over (for cross-channel LRN) or the side
- * length of the square region to sum over (for within-channel LRN). */
- mkldnn_dim_t local_size;
- /** LRN alpha parameter. */
- float lrn_alpha;
- /** LRN beta parameter. */
- float lrn_beta;
- /** LRN k parameter. */
- float lrn_k;
-} mkldnn_lrn_desc_t;
-
-/** A descriptor of a Batch Normalization operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_batch_normalization. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
- */
- mkldnn_prop_kind_t prop_kind;
- /** Source and destination memory descriptor. */
- mkldnn_memory_desc_t data_desc;
- /** Source and destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_data_desc;
- /** Scale and shift data and gradient memory descriptors.
- *
- * The scaleshift memory descriptor uses the 2D #mkldnn_nc format [2, Channels].
- * The first dimension contains the gamma parameter, the second dimension
- * contains the beta parameter. */
- mkldnn_memory_desc_t data_scaleshift_desc;
- mkldnn_memory_desc_t diff_data_scaleshift_desc;
- /** Mean and variance data memory descriptors.
- *
- * Mean and variance memory descriptors use the 1D #mkldnn_x format [Channels].
- */
- mkldnn_memory_desc_t mean_desc;
- mkldnn_memory_desc_t variance_desc;
- /** Batch normalization epsilon parameter. */
- float batch_norm_epsilon;
- unsigned flags;
-} mkldnn_batch_normalization_desc_t;
-
-/** A descriptor of an inner product operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_inner_product. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, #mkldnn_backward_data,
- * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
- mkldnn_prop_kind_t prop_kind;
- /** Source memory descriptor. */
- mkldnn_memory_desc_t src_desc;
- /** Source gradient memory descriptor. */
- mkldnn_memory_desc_t diff_src_desc;
- /** Weights memory descriptor. */
- mkldnn_memory_desc_t weights_desc;
- /** Weights gradient memory descriptor. */
- mkldnn_memory_desc_t diff_weights_desc;
- /** Bias memory descriptor. */
- mkldnn_memory_desc_t bias_desc;
- /** Bias gradient memory descriptor. */
- mkldnn_memory_desc_t diff_bias_desc;
- /** Destination memory descriptor. */
- mkldnn_memory_desc_t dst_desc;
- /** Destination gradient memory descriptor. */
- mkldnn_memory_desc_t diff_dst_desc;
- /** The accumulator data type. Initialized automatically. */
- mkldnn_data_type_t accum_data_type;
-} mkldnn_inner_product_desc_t;
-
-/** Flags for RNN cell. */
-typedef enum {
- mkldnn_rnn_cell_with_relu = 0x1U,
- mkldnn_rnn_cell_with_clipping = 0x2U,
-} mkldnn_rnn_cell_flags_t;
-
-typedef struct {
- /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
- * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru,
- * or #mkldnn_gru_linear_before_reset. */
- mkldnn_alg_kind_t cell_kind;
- /** Activation function used. Must be either #mkldnn_eltwise_relu or
- * #mkldnn_eltwise_tanh. */
- mkldnn_alg_kind_t activation_kind;
- /** RNN cell flags */
- unsigned int flags;
- /** @c alpha is a negative slope parameter (used only if
- * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */
- float alpha;
- /** clipping parameter (used only if
- * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */
- float clipping;
-} mkldnn_rnn_cell_desc_t;
-
-/** A direction of RNN primitive execution. */
-typedef enum {
- /* Unidirectional execution of RNN primitive from left to right. */
- mkldnn_unidirectional_left2right,
- /* Unidirectional execution of RNN primitive from right to left. */
- mkldnn_unidirectional_right2left,
- /* Bidirectional execution of RNN primitive with concatenation of the
- * results. */
- mkldnn_bidirectional_concat,
- /* Bidirectional execution of RNN primitive with summation of the
- * results. */
- mkldnn_bidirectional_sum,
- mkldnn_unidirectional = mkldnn_unidirectional_left2right,
-} mkldnn_rnn_direction_t;
-
-/** A descriptor for an RNN operation. */
-typedef struct {
- /** The kind of primitive. Used for self-identifying the primitive
- * descriptor. Must be #mkldnn_rnn. */
- mkldnn_primitive_kind_t primitive_kind;
- /** The kind of propagation. Possible values: #mkldnn_forward_training,
- * #mkldnn_forward_inference, and #mkldnn_backward. */
- mkldnn_prop_kind_t prop_kind;
- /** The RNN cell desc. */
- mkldnn_rnn_cell_desc_t cell_desc;
- /** The direction of RNN primitive execution. */
- mkldnn_rnn_direction_t direction;
- /** Source layer memory descriptor. */
- mkldnn_memory_desc_t src_layer_desc;
- /** Source iteration memory descriptor. */
- mkldnn_memory_desc_t src_iter_desc;
- /** Weights layer memory descriptor. */
- mkldnn_memory_desc_t weights_layer_desc;
- /** Weights iteration memory descriptor. */
- mkldnn_memory_desc_t weights_iter_desc;
- /** Bias memory descriptor. */
- mkldnn_memory_desc_t bias_desc;
- /** Destination layer memory descriptor. */
- mkldnn_memory_desc_t dst_layer_desc;
- /** Destination iter memory descriptor. */
- mkldnn_memory_desc_t dst_iter_desc;
- /** Source gradient layer memory descriptor. */
- mkldnn_memory_desc_t diff_src_layer_desc;
- /** Source gradient iter memory descriptor. */
- mkldnn_memory_desc_t diff_src_iter_desc;
- /** Weights gradient layer memory descriptor. */
- mkldnn_memory_desc_t diff_weights_layer_desc;
- /** Weights gradient iter memory descriptor. */
- mkldnn_memory_desc_t diff_weights_iter_desc;
- /** Bias gradient memory descriptor. */
- mkldnn_memory_desc_t diff_bias_desc;
- /** Destination gradient layer memory descriptor. */
- mkldnn_memory_desc_t diff_dst_layer_desc;
- /** Destination gradient iteration memory descriptor. */
- mkldnn_memory_desc_t diff_dst_iter_desc;
-} mkldnn_rnn_desc_t;
-
-/** @} */
-
-/** @addtogroup c_api_engine_types Engine
- * @{ */
-
-/** @brief Kinds of engines. */
-typedef enum {
- /** An unspecified engine. */
- mkldnn_any_engine,
- /** CPU engine. */
- mkldnn_cpu,
-} mkldnn_engine_kind_t;
-
-/** @struct mkldnn_engine
- * @brief An opaque structure to describe an engine. */
-struct mkldnn_engine;
-/** @brief An engine handle. */
-typedef struct mkldnn_engine *mkldnn_engine_t;
-#if 0
-/* FIXME: looks like this never happens */
-/** @brief A constant engine handle. */
-typedef const struct mkldnn_engine *const_mkldnn_engine_t;
-#endif
-
-/** @} */
-
-/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators
- * @{ */
-
-/** @struct mkldnn_primitive_desc_iterator
- * @brief An opaque structure to describe a primitive descriptor iterator. */
-struct mkldnn_primitive_desc_iterator;
-
-/** @brief A primitive descriptor iterator handle. */
-typedef struct mkldnn_primitive_desc_iterator
- *mkldnn_primitive_desc_iterator_t;
-
-/** @brief A constant primitive descriptor iterator handle. */
-typedef const struct mkldnn_primitive_desc_iterator
- *const_mkldnn_primitive_desc_iterator_t;
-
-/** @} */
-
-/** @addtogroup c_api_primitive_descs Primitive descriptors
- * @{ */
-
-/** @struct mkldnn_primitive_desc
- * @brief An opaque structure to describe a primitive descriptor. */
-struct mkldnn_primitive_desc;
-
-/** @brief A primitive descriptor handle. */
-typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t;
-
-/** @brief A constant primitive descriptor handle. */
-typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t;
-
-/** @} */
-
-/** @addtogroup c_api_primitive_attr Primitive descriptor attributes
- * @{ */
-
-/** Scratchpad mode */
-typedef enum {
- /** The library manages scratchpad (default) */
- mkldnn_scratchpad_mode_library,
- /** A user shall query and provide the scratchpad memory to primitives */
- mkldnn_scratchpad_mode_user,
-} mkldnn_scratchpad_mode_t;
-
-/** @struct mkldnn_primitive_attr
- * @brief An opaque structure for primitive descriptor attributes.
- *
- * Attributes may contain:
- * - output scales (to scale the result prior to storing it to memory)
- */
-struct mkldnn_primitive_attr;
-
-/** @brief A primitive descriptor attributes handle that controls primitive
- * behavior. */
-typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t;
-
-/** @brief A constant primitive descriptor attributes handle. */
-typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t;
-
-/** @struct mkldnn_post_ops
- * @brief An opaque structure for a chain of post operations.
- *
- * mkldnn_post_ops can be used to perform some (trivial) operations like
- * accumulation or eltwise after certain primitives like convolution.
- *
- * Post operations might be combined together, making a chain of post
- * operations. For instance, one can configure convolution followed by
- * accumulation followed by eltwise. This might be especially beneficial
- * for residual learning blocks.
- *
- * @warning
- * Of course not all combinations are supported, so the user should handle
- * errors accordingly.
- *
- * Supported post operations:
- * - accumulation (base primitive: convolution)
- * - eltwise (base primitive: convolution)
- */
-struct mkldnn_post_ops;
-
-/** @brief A post operation chain handle. */
-typedef struct mkldnn_post_ops *mkldnn_post_ops_t;
-
-/** @brief A constant post operation chain handle. */
-typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t;
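-
-/* Illustrative sketch: building the convolution -> accumulation -> eltwise
- * chain mentioned above (a typical residual block). The mkldnn_post_ops_* and
- * mkldnn_primitive_attr_* function names and signatures are assumed from
- * mkldnn.h; they are not declared in this header. */
-static inline mkldnn_status_t make_residual_attr(mkldnn_primitive_attr_t *attr) {
-    mkldnn_post_ops_t ops;
-    mkldnn_post_ops_create(&ops);
-    mkldnn_post_ops_append_sum(ops, 1.f); /* accumulate into the destination */
-    mkldnn_post_ops_append_eltwise(ops, 1.f, mkldnn_eltwise_relu, 0.f, 0.f);
-    mkldnn_primitive_attr_create(attr);
-    return mkldnn_primitive_attr_set_post_ops(*attr, ops);
-}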
-
-/** @} */
-
-/** @addtogroup c_api_types_primitive Primitive
- * @{ */
-
-/** @struct mkldnn_primitive
- * An opaque structure to describe a primitive. */
-struct mkldnn_primitive;
-/** A primitive handle. */
-typedef struct mkldnn_primitive *mkldnn_primitive_t;
-/** A constant primitive handle. */
-typedef const struct mkldnn_primitive *const_mkldnn_primitive_t;
-
-/** @addtogroup c_api_types_arguments Argument indices
- * @{ */
-
-#define MKLDNN_ARG_SRC_0 1
-#define MKLDNN_ARG_SRC MKLDNN_ARG_SRC_0
-#define MKLDNN_ARG_SRC_LAYER MKLDNN_ARG_SRC_0
-#define MKLDNN_ARG_FROM MKLDNN_ARG_SRC_0
-
-#define MKLDNN_ARG_SRC_1 2
-#define MKLDNN_ARG_SRC_ITER MKLDNN_ARG_SRC_1
-
-#define MKLDNN_ARG_DST_0 17
-#define MKLDNN_ARG_DST MKLDNN_ARG_DST_0
-#define MKLDNN_ARG_TO MKLDNN_ARG_DST_0
-#define MKLDNN_ARG_DST_LAYER MKLDNN_ARG_DST_0
-
-#define MKLDNN_ARG_DST_1 18
-#define MKLDNN_ARG_DST_ITER MKLDNN_ARG_DST_1
-
-#define MKLDNN_ARG_WEIGHTS_0 33
-#define MKLDNN_ARG_WEIGHTS MKLDNN_ARG_WEIGHTS_0
-#define MKLDNN_ARG_SCALE_SHIFT MKLDNN_ARG_WEIGHTS_0
-#define MKLDNN_ARG_WEIGHTS_LAYER MKLDNN_ARG_WEIGHTS_0
-
-#define MKLDNN_ARG_WEIGHTS_1 34
-#define MKLDNN_ARG_WEIGHTS_ITER MKLDNN_ARG_WEIGHTS_1
-
-#define MKLDNN_ARG_BIAS 41
-
-#define MKLDNN_ARG_MEAN 49
-#define MKLDNN_ARG_VARIANCE 50
-
-#define MKLDNN_ARG_WORKSPACE 64
-#define MKLDNN_ARG_SCRATCHPAD 80
-
-#define MKLDNN_ARG_DIFF_SRC_0 129
-#define MKLDNN_ARG_DIFF_SRC MKLDNN_ARG_DIFF_SRC_0
-#define MKLDNN_ARG_DIFF_SRC_LAYER MKLDNN_ARG_DIFF_SRC_0
-
-#define MKLDNN_ARG_DIFF_SRC_1 130
-#define MKLDNN_ARG_DIFF_SRC_ITER MKLDNN_ARG_DIFF_SRC_1
-
-#define MKLDNN_ARG_DIFF_DST_0 145
-#define MKLDNN_ARG_DIFF_DST MKLDNN_ARG_DIFF_DST_0
-#define MKLDNN_ARG_DIFF_DST_LAYER MKLDNN_ARG_DIFF_DST_0
-
-#define MKLDNN_ARG_DIFF_DST_1 146
-#define MKLDNN_ARG_DIFF_DST_ITER MKLDNN_ARG_DIFF_DST_1
-
-#define MKLDNN_ARG_DIFF_WEIGHTS_0 161
-#define MKLDNN_ARG_DIFF_WEIGHTS MKLDNN_ARG_DIFF_WEIGHTS_0
-#define MKLDNN_ARG_DIFF_SCALE_SHIFT MKLDNN_ARG_DIFF_WEIGHTS_0
-#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER MKLDNN_ARG_DIFF_WEIGHTS_0
-
-#define MKLDNN_ARG_DIFF_WEIGHTS_1 162
-#define MKLDNN_ARG_DIFF_WEIGHTS_ITER MKLDNN_ARG_DIFF_WEIGHTS_1
-
-#define MKLDNN_ARG_DIFF_BIAS 169
-
-#define MKLDNN_ARG_MULTIPLE_SRC 1024
-#define MKLDNN_ARG_MULTIPLE_DST 2048
-
-/** @} */
-
-/** An auxiliary structure to specify a primitive's inputs/outputs at
- * execution time.
- *
- * @warning
- * With this API it is impossible to preserve the constness of memory, so all
- * memories are passed without the const qualifier. However, only memories
- * with output semantics may be changed during the execution. */
-typedef struct {
- int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */
- mkldnn_memory_t memory; /**< Input/output memory */
-} mkldnn_exec_arg_t;
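-
-/* Illustrative sketch: an argument list for executing a forward convolution,
- * using the indices defined above. Here src_mem, wei_mem, bias_mem, and
- * dst_mem are hypothetical mkldnn_memory_t handles created elsewhere, and the
- * array would be built inside the function that executes the primitive. */
-#if 0
-    mkldnn_exec_arg_t conv_args[] = {
-        { MKLDNN_ARG_SRC, src_mem },
-        { MKLDNN_ARG_WEIGHTS, wei_mem },
-        { MKLDNN_ARG_BIAS, bias_mem },
-        { MKLDNN_ARG_DST, dst_mem },
-    };
-#endif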
-
-/** @} */
-
-/** @addtogroup c_api_types_query Queries
- * @{ */
-
-/** Primitive descriptor query specification
- *
- * For the generic function mkldnn_primitive_desc_query(), the type of the
- * result must agree with the queried argument. The correspondence table:
- * Query | type of result
- * --------------------------------------------------------------
- * #mkldnn_query_engine | mkldnn_engine_t *
- * #mkldnn_query_scratchpad_engine | mkldnn_engine_t *
- * #mkldnn_query_primitive_kind | mkldnn_primitive_kind_t *
- * *_s32 | int *
- * *_s64 | mkldnn_dim_t * (same as int64_t *)
- * *_f64 | double *
- * *_str | const char **
- * #mkldnn_query_op_d | const_mkldnn_op_desc_t *
- * *_md | const mkldnn_memory_desc_t **
- * *_${op}_d | const mkldnn_${op}_desc_t **
- * *_pd | const_mkldnn_primitive_desc_t *
- *
- * @note
- * Rule of thumb: all opaque types and structures are returned by
- * reference. All numbers are returned by value.
- *
- * @warning
- * All returned references point to constant objects and are valid only
- * during the lifetime of the queried primitive descriptor. Returned objects
- * must not be destroyed by the user. If you need to keep the object longer
- * than the lifetime of the queried primitive descriptor, use
- * mkldnn_primitive_desc_clone() to make a copy. */
-typedef enum {
- mkldnn_query_undef = 0, /**< no query */
-
- mkldnn_query_engine, /**< execution engine */
- mkldnn_query_primitive_kind, /**< primitive kind */
-
- mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */
- mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */
-
- mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */
- mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra
- (scratch) memory, in addition to all
- inputs and outputs memory (bytes) */
-
- mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used
- for creating scratchpad memory */
-
- mkldnn_query_impl_info_str, /**< implementation name */
-
- /* memory and op descriptor section */
- mkldnn_query_some_d = 64, /**< stub */
- mkldnn_query_op_d, /**< op descriptor */
- mkldnn_query_convolution_d, /**< convolution descriptor */
- mkldnn_query_deconvolution_d, /**< deconvolution descriptor */
- mkldnn_query_shuffle_d, /**< shuffle descriptor */
- mkldnn_query_eltwise_d, /**< eltwise descriptor */
- mkldnn_query_softmax_d, /**< softmax descriptor */
- mkldnn_query_pooling_d, /**< pooling descriptor */
- mkldnn_query_lrn_d, /**< lrn descriptor */
- mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */
- mkldnn_query_inner_product_d, /**< inner product descriptor */
- mkldnn_query_rnn_d, /**< rnn descriptor */
-
- /* memory descriptor section */
- mkldnn_query_some_md = 128, /**< stub */
- mkldnn_query_src_md, /**< source memory desc */
- mkldnn_query_diff_src_md, /**< source gradient memory desc */
- mkldnn_query_weights_md, /**< weights memory desc */
- mkldnn_query_diff_weights_md, /**< weights grad. memory desc */
- mkldnn_query_dst_md, /**< destination memory desc */
- mkldnn_query_diff_dst_md, /**< destination grad. memory desc */
- mkldnn_query_workspace_md, /**< workspace memory desc */
- mkldnn_query_scratchpad_md, /**< scratchpad memory desc */
-} mkldnn_query_t;
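-
-/* Illustrative sketch of the correspondence table above: querying the
- * destination memory descriptor from a (hypothetical) convolution primitive
- * descriptor. The mkldnn_primitive_desc_query() signature is assumed from
- * mkldnn.h; per the table, *_md queries return the result through a
- * `const mkldnn_memory_desc_t **` pointer. */
-static inline const mkldnn_memory_desc_t *query_dst_md_example(
-        const_mkldnn_primitive_desc_t conv_pd) {
-    const mkldnn_memory_desc_t *dst_md = NULL;
-    if (mkldnn_primitive_desc_query(conv_pd, mkldnn_query_dst_md, 0, &dst_md)
-            != mkldnn_success)
-        return NULL;
-    return dst_md;
-}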
-
-/** @} */
-
-/** @addtogroup c_api_types_stream Execution stream
- * @{ */
-
-/** @brief Stream flags. */
-typedef enum {
- /** A default stream configuration. */
- mkldnn_stream_default_flags = 0x0U,
-} mkldnn_stream_flags_t;
-
-/** @struct mkldnn_stream
- * An opaque structure to describe an execution stream. */
-struct mkldnn_stream;
-/** An execution stream handle. */
-typedef struct mkldnn_stream *mkldnn_stream_t;
-/** A constant execution stream handle. */
-typedef const struct mkldnn_stream *const_mkldnn_stream_t;
-
-/** @} */
-/** @} */
-/** @} */
-
-#ifdef __cplusplus
-}
-#endif
-
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
deleted file mode 100644
index a2713deccb..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h
+++ /dev/null
@@ -1,32 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_VERSION_H
-#define MKLDNN_VERSION_H
-
-/* Major version of MKL-DNN */
-#define MKLDNN_VERSION_MAJOR 0
-
-/* Minor version of MKL-DNN */
-#define MKLDNN_VERSION_MINOR 90
-
-/* Patch version of MKL-DNN */
-#define MKLDNN_VERSION_PATCH 0
-
-/* Git Commit Hash of MKL-DNN */
-#define MKLDNN_VERSION_HASH "096bda1ca23324879f2df5a129e610e4405f775c"
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in b/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in
deleted file mode 100644
index 5ee0126188..0000000000
--- a/thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in
+++ /dev/null
@@ -1,32 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_VERSION_H
-#define MKLDNN_VERSION_H
-
-/* Major version of MKL-DNN */
-#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@
-
-/* Minor version of MKL-DNN */
-#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@
-
-/* Patch version of MKL-DNN */
-#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@
-
-/* Git Commit Hash of MKL-DNN */
-#define MKLDNN_VERSION_HASH "@MKLDNN_VERSION_HASH@"
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
deleted file mode 100644
index 1a51d8562b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
+++ /dev/null
@@ -1,104 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
- prop_kind_t prop_kind, const memory_desc_t *data_desc,
- const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
- bool args_ok = true
- && !any_null(bnrm_desc, data_desc)
- && one_of(prop_kind, forward_training, forward_inference,
- backward_data, backward)
- && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
- if (!args_ok) return invalid_arguments;
-
- auto bd = batch_normalization_desc_t();
- bd.primitive_kind = primitive_kind::batch_normalization;
- bd.prop_kind = prop_kind;
-
- bd.data_desc = *data_desc;
- bd.diff_data_desc = zero_md();
- if (one_of(bd.prop_kind, backward_data, backward))
- bd.diff_data_desc = *diff_data_desc;
-
- dims_t scaleshift_dims = { 2, data_desc->dims[1] };
- mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
- scaleshift_dims, data_type::f32, mkldnn_nc);
- bd.diff_data_scaleshift_desc = zero_md();
- if (bd.prop_kind == backward) {
- bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
- }
-
- dims_t stats_dims = { data_desc->dims[1] };
- mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
- data_type::f32, mkldnn_x);
- bd.variance_desc = bd.mean_desc;
- bd.batch_norm_epsilon = epsilon;
-
- unsigned bnorm_flags =
- mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
- if ((~bnorm_flags & flags) != 0) return invalid_arguments;
-
- bd.flags = flags;
-
- bool consistency = true
- && utils::one_of(bd.data_desc.ndims, 2, 4, 5);
- if (bd.prop_kind == backward_data)
- consistency = consistency
- && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
- && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
- bd.diff_data_desc.ndims);
- if (!consistency) return invalid_arguments;
-
- *bnrm_desc = bd;
- return success;
-}
-}
-
-status_t mkldnn_batch_normalization_forward_desc_init(
- batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
- const memory_desc_t *data_desc, float epsilon, unsigned flags) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
- epsilon, flags);
-}
-
-status_t mkldnn_batch_normalization_backward_desc_init(
- batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
- const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
- float epsilon, unsigned flags) {
- if (!one_of(prop_kind, backward, backward_data))
- return invalid_arguments;
- return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
- epsilon, flags);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
deleted file mode 100644
index f61410b33c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
+++ /dev/null
@@ -1,240 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef BATCH_NORMALIZATION_PD_HPP
-#define BATCH_NORMALIZATION_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct batch_normalization_fwd_pd_t;
-
-struct batch_normalization_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::batch_normalization;
-
- batch_normalization_pd_t(engine_t *engine,
- const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , data_md_(desc_.data_desc)
- , stat_md_(desc_.mean_desc)
- , scaleshift_md_(desc_.data_scaleshift_desc)
- , ws_md_()
- {}
-
- const batch_normalization_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::batch_normalization_d:
- *(const batch_normalization_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common batch_normalization aux functions */
-
- dim_t MB() const { return data_desc().dims[0]; }
- dim_t C() const { return data_desc().dims[1]; }
- dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
- dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
- dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
-
- int ndims() const { return desc_.data_desc.ndims; }
-
- bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
- bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
- bool use_global_stats() const
- { return desc_.flags & mkldnn_use_global_stats; }
- bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
- bool with_relu_post_op() const {
- const auto &p = this->attr()->post_ops_;
- return p.len_ == 1 && p.entry_[0].is_relu(true, true);
- }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
- bool is_bwd() const { return !this->is_fwd(); }
- bool is_training() const
- { return desc_.prop_kind == prop_kind::forward_training; }
-
- bool has_zero_dim_memory() const
- { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
-
-protected:
- batch_normalization_desc_t desc_;
- const batch_normalization_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t data_md_;
- memory_desc_t stat_md_;
- memory_desc_t scaleshift_md_;
-
- memory_desc_t ws_md_;
-
- void init_default_ws(size_t bits_per_element) {
- const auto data_mdw = memory_desc_wrapper(data_md_);
-
- const dim_t data_nelems = data_mdw.nelems(true);
- const dim_t bits_per_byte = 8;
- const dims_t ws_sz = { (dim_t)utils::div_up(
- data_nelems * bits_per_element, bits_per_byte) };
- mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
- format_tag::x);
- }
-
-private:
- const memory_desc_t &data_desc() const { return desc_.data_desc; }
-};
-
-struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
- typedef batch_normalization_fwd_pd_t base_class;
- typedef batch_normalization_fwd_pd_t hint_class;
-
- batch_normalization_fwd_pd_t(engine_t *engine,
- const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
- if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
-
- if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
- if (stats_is_src()) return arg_usage_t::input;
- if (!stats_is_src() && is_training()) return arg_usage_t::output;
- return arg_usage_t::unused;
- }
-
- if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override {
- if (index == 0) return &data_md_;
- if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
- return nullptr;
- }
-
- virtual const memory_desc_t *dst_md(int index = 0) const override {
- if (index == 0) return &data_md_;
- if (!stats_is_src() && is_training() && (index == 1 || index == 2))
- return &stat_md_;
- return nullptr;
- }
-
- virtual const memory_desc_t *weights_md(int index = 0) const override
- { return index == 0 ? &scaleshift_md_ : nullptr; }
-
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
-
- const memory_desc_t *stat_md() const
- { return stats_is_src() ? src_md(1) : dst_md(1); }
-
- virtual int n_inputs() const override
- { return 1 + 2 * stats_is_src() + use_scaleshift(); }
- virtual int n_outputs() const override
- { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
-};
-
-struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
- typedef batch_normalization_bwd_pd_t base_class;
- typedef batch_normalization_fwd_pd_t hint_class;
-
- batch_normalization_bwd_pd_t(engine_t *engine,
- const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_data_md_(desc_.diff_data_desc)
- , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
- MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
-
- virtual const memory_desc_t *weights_md(int index = 0) const override
- { return index == 0 ? &scaleshift_md_ : nullptr; }
- virtual const memory_desc_t *diff_weights_md(int index = 0) const override
- { return index == 0 ? &diff_scaleshift_md_ : nullptr; }
-
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
-
- const memory_desc_t *stat_md() const { return src_md(1); }
-
- virtual int n_inputs() const override
- { return 4 + use_scaleshift() + fuse_bn_relu(); }
- virtual int n_outputs() const override
- { return 1 + (desc_.prop_kind == prop_kind::backward); }
-
-protected:
- memory_desc_t diff_data_md_;
- memory_desc_t diff_scaleshift_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
deleted file mode 100644
index 3d43a0fbee..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
+++ /dev/null
@@ -1,550 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef TYPE_MAPPING_HPP
-#define TYPE_MAPPING_HPP
-
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-
-// TODO: autogenerate this
-
-using dim_t = mkldnn_dim_t;
-using dims_t = mkldnn_dims_t;
-using stride_t = mkldnn_dim_t;
-using strides_t = mkldnn_strides_t;
-
-using status_t = mkldnn_status_t;
-namespace status {
- const status_t success = mkldnn_success;
- const status_t out_of_memory = mkldnn_out_of_memory;
- const status_t try_again = mkldnn_try_again;
- const status_t invalid_arguments = mkldnn_invalid_arguments;
- const status_t not_ready = mkldnn_not_ready;
- const status_t unimplemented = mkldnn_unimplemented;
- const status_t iterator_ends = mkldnn_iterator_ends;
- const status_t runtime_error = mkldnn_runtime_error;
- const status_t not_required = mkldnn_not_required;
-}
-
-using prop_kind_t = mkldnn_prop_kind_t;
-namespace prop_kind {
- const prop_kind_t undef = mkldnn_prop_kind_undef;
- const prop_kind_t forward_training = mkldnn_forward_training;
- const prop_kind_t forward_inference = mkldnn_forward_inference;
- const prop_kind_t forward_scoring = mkldnn_forward_scoring;
- const prop_kind_t forward = mkldnn_forward;
- const prop_kind_t backward = mkldnn_backward;
- const prop_kind_t backward_data = mkldnn_backward_data;
- const prop_kind_t backward_weights = mkldnn_backward_weights;
- const prop_kind_t backward_bias = mkldnn_backward_bias;
-}
-
-using alg_kind_t = mkldnn_alg_kind_t;
-namespace alg_kind {
- const alg_kind_t undef = mkldnn_alg_kind_undef;
- const alg_kind_t convolution_auto = mkldnn_convolution_auto;
- const alg_kind_t convolution_direct = mkldnn_convolution_direct;
- const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
- const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
- const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
- const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
- const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
- const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
- const alg_kind_t eltwise_square = mkldnn_eltwise_square;
- const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
- const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
- const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
- const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
- const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
- const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
- const alg_kind_t pooling_max = mkldnn_pooling_max;
- const alg_kind_t pooling_avg = mkldnn_pooling_avg;
- const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
- const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
- const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
- const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
- const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
- const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
- const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
- const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
-}
-
-using data_type_t = mkldnn_data_type_t;
-namespace data_type {
- const data_type_t undef = mkldnn_data_type_undef;
- const data_type_t f32 = mkldnn_f32;
- const data_type_t s32 = mkldnn_s32;
- const data_type_t s8 = mkldnn_s8;
- const data_type_t u8 = mkldnn_u8;
-}
-
-using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
-namespace scratchpad_mode {
- const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
- const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
-}
-
-using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
-namespace rnn_packed_format {
- const rnn_packed_format_t undef = mkldnn_packed_format_undef;
- const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
- const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
-}
-
-using format_kind_t = mkldnn_format_kind_t;
-namespace format_kind {
- const format_kind_t undef = mkldnn_format_kind_undef;
- const format_kind_t any = mkldnn_format_kind_any;
- const format_kind_t blocked = mkldnn_blocked;
- const format_kind_t wino = mkldnn_format_kind_wino;
- const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
-}
-
-using format_tag_t = mkldnn_format_tag_t;
-namespace format_tag {
- const format_tag_t undef = mkldnn_format_tag_undef;
- const format_tag_t any = mkldnn_format_tag_any;
- const format_tag_t a = mkldnn_a;
- const format_tag_t ab = mkldnn_ab;
- const format_tag_t abc = mkldnn_abc;
- const format_tag_t abcd = mkldnn_abcd;
- const format_tag_t abcde = mkldnn_abcde;
- const format_tag_t abcdef = mkldnn_abcdef;
- const format_tag_t abdec = mkldnn_abdec;
- const format_tag_t acb = mkldnn_acb;
- const format_tag_t acbde = mkldnn_acbde;
- const format_tag_t acdb = mkldnn_acdb;
- const format_tag_t acdeb = mkldnn_acdeb;
- const format_tag_t ba = mkldnn_ba;
- const format_tag_t bac = mkldnn_bac;
- const format_tag_t bacd = mkldnn_bacd;
- const format_tag_t bcda = mkldnn_bcda;
- const format_tag_t cba = mkldnn_cba;
- const format_tag_t cdba = mkldnn_cdba;
- const format_tag_t cdeba = mkldnn_cdeba;
- const format_tag_t decab = mkldnn_decab;
- const format_tag_t Abc16a = mkldnn_Abc16a;
- const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
- const format_tag_t aBc16b = mkldnn_aBc16b;
- const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
- const format_tag_t Abc4a = mkldnn_Abc4a;
- const format_tag_t aBc4b = mkldnn_aBc4b;
- const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
- const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
- const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
- const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
- const format_tag_t aBc8b = mkldnn_aBc8b;
- const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
- const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
- const format_tag_t Abcd16a = mkldnn_Abcd16a;
- const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
- const format_tag_t aBcd16b = mkldnn_aBcd16b;
- const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
- const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
- const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
- const format_tag_t Abcd4a = mkldnn_Abcd4a;
- const format_tag_t aBcd4b = mkldnn_aBcd4b;
- const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
- const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
- const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
- const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
- const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
- const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
- const format_tag_t aBcd8b = mkldnn_aBcd8b;
- const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
- const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
- const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
- const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
- const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
- const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
- const format_tag_t Abcde16a = mkldnn_Abcde16a;
- const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
- const format_tag_t aBcde16b = mkldnn_aBcde16b;
- const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
- const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
- const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
- const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
- const format_tag_t Abcde4a = mkldnn_Abcde4a;
- const format_tag_t aBcde4b = mkldnn_aBcde4b;
- const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
- const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
- const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
- const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
- const format_tag_t Abcde8a = mkldnn_Abcde8a;
- const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
- const format_tag_t aBcde8b = mkldnn_aBcde8b;
- const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
- const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
- const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
- const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
- const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
- const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
- const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
- const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
- const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
- const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
- const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
- const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
- const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
- const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
- const format_tag_t aBdc16b = mkldnn_aBdc16b;
- const format_tag_t aBdc4b = mkldnn_aBdc4b;
- const format_tag_t aBdc8b = mkldnn_aBdc8b;
- const format_tag_t aBdec16b = mkldnn_aBdec16b;
- const format_tag_t aBdec4b = mkldnn_aBdec4b;
- const format_tag_t aBdec8b = mkldnn_aBdec8b;
- const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
- const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
- const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
- const format_tag_t Acb16a = mkldnn_Acb16a;
- const format_tag_t Acb4a = mkldnn_Acb4a;
- const format_tag_t Acb8a = mkldnn_Acb8a;
- const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
- const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
- const format_tag_t Acdb16a = mkldnn_Acdb16a;
- const format_tag_t Acdb4a = mkldnn_Acdb4a;
- const format_tag_t Acdb8a = mkldnn_Acdb8a;
- const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
- const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
- const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
- const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
- const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
- const format_tag_t last = mkldnn_format_tag_last;
-
- const format_tag_t x = mkldnn_x;
- const format_tag_t nc = mkldnn_nc;
- const format_tag_t cn = mkldnn_cn;
- const format_tag_t ncw = mkldnn_ncw;
- const format_tag_t nwc = mkldnn_nwc;
- const format_tag_t nchw = mkldnn_nchw;
- const format_tag_t nhwc = mkldnn_nhwc;
- const format_tag_t chwn = mkldnn_chwn;
- const format_tag_t ncdhw = mkldnn_ncdhw;
- const format_tag_t ndhwc = mkldnn_ndhwc;
- const format_tag_t oi = mkldnn_oi;
- const format_tag_t io = mkldnn_io;
- const format_tag_t oiw = mkldnn_oiw;
- const format_tag_t wio = mkldnn_wio;
- const format_tag_t oihw = mkldnn_oihw;
- const format_tag_t hwio = mkldnn_hwio;
- const format_tag_t ihwo = mkldnn_ihwo;
- const format_tag_t iohw = mkldnn_iohw;
- const format_tag_t oidhw = mkldnn_oidhw;
- const format_tag_t dhwio = mkldnn_dhwio;
- const format_tag_t goiw = mkldnn_goiw;
- const format_tag_t goihw = mkldnn_goihw;
- const format_tag_t hwigo = mkldnn_hwigo;
- const format_tag_t giohw = mkldnn_giohw;
- const format_tag_t goidhw = mkldnn_goidhw;
- const format_tag_t tnc = mkldnn_tnc;
- const format_tag_t ntc = mkldnn_ntc;
- const format_tag_t ldsnc = mkldnn_ldsnc;
- const format_tag_t ldigo = mkldnn_ldigo;
- const format_tag_t ldgoi = mkldnn_ldgoi;
- const format_tag_t ldgo = mkldnn_ldgo;
- const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
- const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
- const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
- const format_tag_t nChw16c = mkldnn_nChw16c;
- const format_tag_t nChw4c = mkldnn_nChw4c;
- const format_tag_t nChw8c = mkldnn_nChw8c;
- const format_tag_t nCw16c = mkldnn_nCw16c;
- const format_tag_t nCw4c = mkldnn_nCw4c;
- const format_tag_t nCw8c = mkldnn_nCw8c;
- const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
- const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
- const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
- const format_tag_t Oiw16o = mkldnn_Oiw16o;
- const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
- const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
- const format_tag_t Oiw4o = mkldnn_Oiw4o;
- const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
- const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
- const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
- const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
- const format_tag_t Owi16o = mkldnn_Owi16o;
- const format_tag_t Owi4o = mkldnn_Owi4o;
- const format_tag_t Owi8o = mkldnn_Owi8o;
- const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
- const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
- const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
- const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
- const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
- const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
- const format_tag_t Oihw16o = mkldnn_Oihw16o;
- const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
- const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
- const format_tag_t Oihw4o = mkldnn_Oihw4o;
- const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
- const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
- const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
- const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
- const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
- const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
- const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
- const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
- const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
- const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
- const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
- const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
- const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
- const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
- const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
- const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
- const format_tag_t Goiw16g = mkldnn_Goiw16g;
- const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
- const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
- const format_tag_t gOiw16o = mkldnn_gOiw16o;
- const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
- const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
- const format_tag_t gOiw4o = mkldnn_gOiw4o;
- const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
- const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
- const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
- const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
- const format_tag_t gOwi16o = mkldnn_gOwi16o;
- const format_tag_t gOwi4o = mkldnn_gOwi4o;
- const format_tag_t gOwi8o = mkldnn_gOwi8o;
- const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
- const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
- const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
- const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
- const format_tag_t Goihw16g = mkldnn_Goihw16g;
- const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
- const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
- const format_tag_t gOihw16o = mkldnn_gOihw16o;
- const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
- const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
- const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
- const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
- const format_tag_t gOihw4o = mkldnn_gOihw4o;
- const format_tag_t Goihw8g = mkldnn_Goihw8g;
- const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
- const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
- const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
- const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
- const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
- const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
- const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
- const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
- const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
- const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
- const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
- const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
- const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
- const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
- const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
-}
-
-using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
-namespace memory_extra_flags {
- const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
- const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
- const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
-}
-
-using padding_kind_t = mkldnn_padding_kind_t;
-namespace padding_kind {
- const padding_kind_t padding_zero = mkldnn_padding_zero;
-}
-
-using engine_kind_t = mkldnn_engine_kind_t;
-namespace engine_kind {
- const engine_kind_t any_engine = mkldnn_any_engine;
- const engine_kind_t cpu = mkldnn_cpu;
-}
-
-using primitive_kind_t = mkldnn_primitive_kind_t;
-namespace primitive_kind {
- const primitive_kind_t undefined = mkldnn_undefined_primitive;
- const primitive_kind_t reorder = mkldnn_reorder;
- const primitive_kind_t concat = mkldnn_concat;
- const primitive_kind_t sum = mkldnn_sum;
- const primitive_kind_t convolution = mkldnn_convolution;
- const primitive_kind_t deconvolution = mkldnn_deconvolution;
- const primitive_kind_t shuffle = mkldnn_shuffle;
- const primitive_kind_t eltwise = mkldnn_eltwise;
- const primitive_kind_t softmax = mkldnn_softmax;
- const primitive_kind_t pooling = mkldnn_pooling;
- const primitive_kind_t lrn = mkldnn_lrn;
- const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
- const primitive_kind_t inner_product = mkldnn_inner_product;
- const primitive_kind_t rnn = mkldnn_rnn;
-}
-
-using query_t = mkldnn_query_t;
-namespace query {
- const query_t undef = mkldnn_query_undef;
-
- const query_t engine = mkldnn_query_engine;
- const query_t primitive_kind = mkldnn_query_primitive_kind;
-
- const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
- const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
-
- const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
- const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
-
- const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
-
- const query_t impl_info_str = mkldnn_query_impl_info_str;
-
- const query_t some_d = mkldnn_query_some_d;
- const query_t op_d = mkldnn_query_op_d;
- const query_t convolution_d = mkldnn_query_convolution_d;
- const query_t deconvolution_d = mkldnn_query_deconvolution_d;
- const query_t shuffle_d = mkldnn_query_shuffle_d;
- const query_t eltwise_d = mkldnn_query_eltwise_d;
- const query_t softmax_d = mkldnn_query_softmax_d;
- const query_t pooling_d = mkldnn_query_pooling_d;
- const query_t lrn_d = mkldnn_query_lrn_d;
- const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
- const query_t inner_product_d = mkldnn_query_inner_product_d;
- const query_t rnn_d = mkldnn_query_rnn_d;
-
- const query_t some_md = mkldnn_query_some_md;
- const query_t src_md = mkldnn_query_src_md;
- const query_t diff_src_md = mkldnn_query_diff_src_md;
- const query_t weights_md = mkldnn_query_weights_md;
- const query_t diff_weights_md = mkldnn_query_diff_weights_md;
- const query_t dst_md = mkldnn_query_dst_md;
- const query_t diff_dst_md = mkldnn_query_diff_dst_md;
-
- const query_t workspace_md = mkldnn_query_workspace_md;
- const query_t scratchpad_md = mkldnn_query_scratchpad_md;
-}
-
-using blocking_desc_t = mkldnn_blocking_desc_t;
-using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
-using wino_desc_t = mkldnn_wino_desc_t;
-using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
-using memory_desc_t = mkldnn_memory_desc_t;
-using convolution_desc_t = mkldnn_convolution_desc_t;
-using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
-using shuffle_desc_t = mkldnn_shuffle_desc_t;
-using pooling_desc_t = mkldnn_pooling_desc_t;
-using eltwise_desc_t = mkldnn_eltwise_desc_t;
-using softmax_desc_t = mkldnn_softmax_desc_t;
-using lrn_desc_t = mkldnn_lrn_desc_t;
-using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
-using inner_product_desc_t = mkldnn_inner_product_desc_t;
-
-using rnn_direction_t = mkldnn_rnn_direction_t;
-using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
-using rnn_desc_t = mkldnn_rnn_desc_t;
-
-/* C op_desc_t, which eventually are just (void*) */
-using c_op_desc_t = mkldnn_op_desc_t;
-using const_c_op_desc_t = const_mkldnn_op_desc_t;
-
-struct op_desc_t {
- union {
- primitive_kind_t kind;
- convolution_desc_t convolution;
- deconvolution_desc_t deconvolution;
- shuffle_desc_t shuffle;
- pooling_desc_t pooling;
- eltwise_desc_t eltwise;
- softmax_desc_t softmax;
- lrn_desc_t lrn;
- batch_normalization_desc_t batch_normalization;
- inner_product_desc_t inner_product;
- rnn_desc_t rnn;
- };
-
- op_desc_t(const primitive_kind_t &_): kind(_) {}
-
-# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
- op_desc_t(const c_type &_): name(_) {} \
- static op_desc_t *convert_from_c(c_type *_) \
- { return reinterpret_cast<op_desc_t*>(_); } \
- static const op_desc_t *convert_from_c(const c_type *_) \
- { return reinterpret_cast<const op_desc_t*>(_); }
-
- DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
- DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
- DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
- DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
- DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
- DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
- DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
- DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
- DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
-
-# undef DECL_CTOR_AND_CONVERTERS
-};
-
-using engine_t = mkldnn_engine;
-using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
-using primitive_desc_t = mkldnn_primitive_desc;
-using primitive_attr_t = mkldnn_primitive_attr;
-using post_ops_t = mkldnn_post_ops;
-using memory_t = mkldnn_memory;
-using primitive_t = mkldnn_primitive;
-
-using primitive_arg_index_t = int;
-
-using stream_flags_t = mkldnn_stream_flags_t;
-namespace stream_flags {
- const stream_flags_t default_flags = mkldnn_stream_default_flags;
-}
-using stream_t = mkldnn_stream;
-
-/* forward declaration of the internal primitive_desc types */
-struct batch_normalization_bwd_pd_t;
-struct batch_normalization_fwd_pd_t;
-struct batch_normalization_pd_t;
-struct concat_pd_t;
-struct convolution_bwd_data_pd_t;
-struct convolution_bwd_weights_pd_t;
-struct convolution_fwd_pd_t;
-struct convolution_pd_t;
-struct deconvolution_bwd_data_pd_t;
-struct deconvolution_bwd_weights_pd_t;
-struct deconvolution_fwd_pd_t;
-struct deconvolution_pd_t;
-struct eltwise_bwd_pd_t;
-struct eltwise_fwd_pd_t;
-struct eltwise_pd_t;
-struct inner_product_bwd_data_pd_t;
-struct inner_product_bwd_weights_pd_t;
-struct inner_product_fwd_pd_t;
-struct inner_product_pd_t;
-struct lrn_bwd_pd_t;
-struct lrn_fwd_pd_t;
-struct lrn_pd_t;
-struct pooling_bwd_pd_t;
-struct pooling_fwd_pd_t;
-struct pooling_pd_t;
-struct reorder_pd_t;
-struct rnn_bwd_pd_t;
-struct rnn_fwd_pd_t;
-struct rnn_pd_t;
-struct shuffle_pd_t;
-struct softmax_bwd_pd_t;
-struct softmax_fwd_pd_t;
-struct softmax_pd_t;
-struct sum_pd_t;
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
deleted file mode 100644
index ed4c35c6e9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
+++ /dev/null
@@ -1,86 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "concat_pd.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-
-status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
- const memory_desc_t *dst_md, int n, int concat_dim,
- const memory_desc_t *src_mds,
- const primitive_attr_t *attr,
- engine_t *engine) {
- bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
- if (!args_ok) return invalid_arguments;
-
- const primitive_attr_t dummy_attr;
- if (attr == NULL)
- attr = &dummy_attr;
-
- const int ndims = src_mds[0].ndims;
- const dims_t &dims = src_mds[0].dims;
- const data_type_t dt = src_mds[0].data_type;
-
- int concat_dim_sz = dims[concat_dim];
- for (int i = 1; i < n; ++i) {
- if (src_mds[i].ndims != ndims) return invalid_arguments;
- for (int d = 0; d < ndims; ++d) {
- if (d == concat_dim) continue;
- if (src_mds[i].dims[d] != dims[d])
- return invalid_arguments;
- }
- if (src_mds[i].data_type != dt) return invalid_arguments;
- concat_dim_sz += src_mds[i].dims[concat_dim];
- }
-
- memory_desc_t dummy_dst_md;
- if (dst_md) {
- if (dst_md->ndims != ndims) return invalid_arguments;
- for (int d = 0; d < ndims; ++d) {
- if (dst_md->dims[d] !=
- (d == concat_dim ? concat_dim_sz : dims[d]))
- return invalid_arguments;
- }
- } else {
- dummy_dst_md = src_mds[0];
- dummy_dst_md.dims[concat_dim] = concat_dim_sz;
- dummy_dst_md.format_kind = format_kind::any;
- dst_md = &dummy_dst_md;
- }
-
- auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
-
- for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
- if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
- == success) {
- (*c_pd)->init_info();
- (*c_pd)->init_scratchpad_md();
- return success;
- }
- }
- return unimplemented;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
deleted file mode 100644
index 29311927e2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
+++ /dev/null
@@ -1,211 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CONCAT_PD_HPP
-#define CONCAT_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct concat_pd_t: public primitive_desc_t {
- concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
- const memory_desc_t *dst_md, int n, int concat_dim,
- const memory_desc_t *src_mds)
- : primitive_desc_t(engine, attr, primitive_kind::concat)
- , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
- {
- src_mds_.reserve(n_);
- for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
- }
-
- concat_pd_t(const concat_pd_t &rhs) = default;
-
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg >= MKLDNN_ARG_MULTIPLE_SRC
- && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index < n_inputs() ? &src_mds_[index] : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
-
- virtual int n_inputs() const override { return n_; }
- virtual int n_outputs() const override { return 1; }
-
- int concat_dim() const { return concat_dim_; }
-
- const memory_desc_t *src_image_md(int index = 0) const
- { return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
-
-protected:
- int n_, concat_dim_;
- memory_desc_t dst_md_;
- nstl::vector<memory_desc_t> src_mds_;
-
- /* contains images of srcs in the dst memory (if possible)
- * Lives here to simplify some implementations. An implementation might
- * use this auxiliary array iff init() returned success */
- nstl::vector<memory_desc_t> src_image_mds_;
-
-protected:
- /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- for (int i = 0; i < n_; ++i) {
- const memory_desc_wrapper i_d(&src_mds_[i]);
- if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
- return status::unimplemented;
- }
-
- const int ndims = dst_md_.ndims;
- int current_concat_dim_offset = 0;
- for (int i = 0; i < n_; ++i) {
- const int dim = src_mds_[i].dims[concat_dim_];
- dims_t dims, offsets = {};
- utils::array_copy(dims, dst_md_.dims, ndims);
- dims[concat_dim_] = dim;
- offsets[concat_dim_] = current_concat_dim_offset;
-
- memory_desc_t src_img_d;
- status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
- &dst_md_, dims, offsets);
- if (status != status::success) return status;
- src_image_mds_.push_back(src_img_d);
- current_concat_dim_offset += dim;
- }
-
- return status::success;
- }
-
- status_t set_default_params() {
- if (dst_md_.format_kind != format_kind::any)
- return status::success;
-
- const int ndims = dst_md_.ndims;
-
- /* The stupidest ever heuristics (but not the same as we had before):
- * - Pick the first non-plain format;
- * - If all formats are plain or it is not possible to create a
- * blocked format for the output, pick the format of the plain input
- * - If this fails as well, use plain layout (abcd...)
- */
- status_t status = status::unimplemented;
- for (int i = 0; i < n_; ++i) {
- const memory_desc_wrapper src_d(src_mds_[i]);
- if (src_d.is_blocking_desc() && !src_d.is_plain()) {
- status = memory_desc_init_by_blocking_desc(dst_md_,
- src_d.blocking_desc());
- if (status == status::success) break;
- }
- }
-
- if (status == status::success) {
- /* check if we can create a sub-memory for the dst */
- bool desired_format_ok = true;
- int current_concat_dim_offset = 0;
- for (int i = 0; i < n_; ++i) {
- const int dim = src_mds_[i].dims[concat_dim_];
- dims_t dims, offsets = {};
- utils::array_copy(dims, dst_md_.dims, ndims);
- dims[concat_dim_] = dim;
- offsets[concat_dim_] = current_concat_dim_offset;
-
- memory_desc_t src_img_d;
- status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
- &dst_md_, dims, offsets);
- if (status != status::success) {
- desired_format_ok = false;
- break;
- }
- current_concat_dim_offset += dim;
- }
-
- if (!desired_format_ok)
- status = status::unimplemented;
- }
-
- /* if no success so far, try using the format of the first plain input */
- if (status != status::success) {
- for (int i = 0; i < n_; ++i) {
- const memory_desc_wrapper src_d(src_mds_[i]);
- if (src_d.is_blocking_desc() && src_d.is_plain()) {
- status = memory_desc_init_by_blocking_desc(dst_md_,
- memory_desc_wrapper(src_mds_[0]).blocking_desc());
- if (status == status::success) return status;
- }
- }
- }
-
- /* the last line of defense: use plain abcd... format */
- if (status != status::success)
- status = memory_desc_init_by_strides(dst_md_, nullptr);
-
- return status;
- }
-};
-
-#define DECLARE_CONCAT_PD_t(impl_name, ...) \
- static status_t create(concat_pd_t **concat_pd, \
- engine_t *engine, const primitive_attr_t *attr, \
- const memory_desc_t *dst_md, int n, int concat_dim, \
- const memory_desc_t *src_mds) { \
- using namespace status; \
- auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
- if (_pd == nullptr) return out_of_memory; \
- if (_pd->init() != success) { delete _pd; return unimplemented; } \
- return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
- } \
- virtual status_t create_primitive(primitive_t **p) const override { \
- double ms = get_msec(); \
- auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
- ms = get_msec() - ms; \
- if (mkldnn_verbose()->level >= 2) { \
- printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
- fflush(0); \
- } \
- return ret; \
- } \
- virtual pd_t *clone() const override { return new pd_t(*this); } \
- virtual const char *name() const override { return impl_name; } \
-
-#define DECLARE_CONCAT_PD_T(impl_name, ...) \
- DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
-
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
deleted file mode 100644
index 0c5c02bcd1..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
+++ /dev/null
@@ -1,200 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace mkldnn {
-namespace impl {
-status_t conv_desc_init(convolution_desc_t *conv_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t dilates,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- bool args_ok = true
- && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
- padding_l)
- && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
- && one_of(padding_kind, padding_kind::padding_zero);
- if (!args_ok) return invalid_arguments;
-
- if (padding_r == nullptr) padding_r = padding_l;
-
- auto cd = convolution_desc_t();
- cd.primitive_kind = primitive_kind::convolution;
- cd.prop_kind = prop_kind;
- cd.alg_kind = alg_kind;
-
- cd.diff_src_desc = cd.src_desc = zero_md();
- cd.diff_dst_desc = cd.dst_desc = zero_md();
- cd.diff_weights_desc = cd.weights_desc = zero_md();
- cd.diff_bias_desc = cd.bias_desc = zero_md();
-
- const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
- const bool with_bias =
- bias_desc && bias_desc->format_kind != format_kind::undef;
- const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
-
- (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
- (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
- (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
- *weights_desc;
- if (with_bias)
- (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
- *bias_desc;
-
- int sp_dims = src_desc->ndims - 2;
- utils::array_copy(cd.strides, strides, sp_dims);
- utils::array_copy(cd.padding[0], padding_l, sp_dims);
- utils::array_copy(cd.padding[1], padding_r, sp_dims);
- if (dilates)
- utils::array_copy(cd.dilates, dilates, sp_dims);
- else
- utils::array_set(cd.dilates, 0, sp_dims);
-
- cd.padding_kind = padding_kind;
- cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
- weights_desc->data_type, dst_desc->data_type, prop_kind);
-
- const int g = with_groups ? weights_desc->dims[0] : 1;
- const int bias_dim = prop_kind == backward_data
- ? src_desc->dims[1]
- : dst_desc->dims[1];
-
- bool consistency = true
- && memory_desc_wrapper(weights_desc).nelems()
- && src_desc->ndims == dst_desc->ndims
- && utils::one_of(src_desc->ndims, 3, 4, 5)
- && utils::one_of(weights_desc->ndims, src_desc->ndims,
- src_desc->ndims + 1)
- && (with_bias ? bias_desc->ndims == 1 : true)
- && (with_bias ? bias_desc->dims[0] == bias_dim : true)
- && src_desc->dims[0] == dst_desc->dims[0]
- && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
- && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
- for (int i = 2; i < src_desc->ndims; ++i)
- {
- int src = src_desc->dims[i];
- int ker = weights_desc->dims[with_groups + i];
- int dil = cd.dilates[i - 2];
- int pad_l = padding_l[i - 2];
- int pad_r = padding_r[i - 2];
- int str = strides[i - 2];
- int dst = dst_desc->dims[i];
- int ker_range = 1 + (ker - 1) * (dil + 1);
-
- if (str < 1) return invalid_arguments;
- consistency = consistency
- && dil >= 0
- && pad_l >= 0
- && pad_r + str > 0
- && (src - ker_range + pad_l + pad_r) / str + 1 == dst;
- }
- if (!consistency) return invalid_arguments;
-
- *conv_desc = cd;
- return success;
-}
-}
-}
-
-status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
- weights_desc, bias_desc, dst_desc, strides, nullptr,
- padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_convolution_forward_desc_init(
- convolution_desc_t *conv_desc, prop_kind_t prop_kind,
- alg_kind_t alg_kind, const memory_desc_t *src_desc,
- const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_desc, const dims_t strides,
- const dims_t dilates, const dims_t padding_l,
- const dims_t padding_r, padding_kind_t padding_kind) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
- weights_desc, bias_desc, dst_desc, strides, dilates,
- padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_convolution_backward_data_desc_init(
- convolution_desc_t *conv_desc, alg_kind_t alg_kind,
- const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
- weights_desc, nullptr, diff_dst_desc, strides, nullptr,
- padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_convolution_backward_data_desc_init(
- convolution_desc_t *conv_desc, alg_kind_t alg_kind,
- const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
- weights_desc, nullptr, diff_dst_desc, strides, dilates,
- padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_convolution_backward_weights_desc_init(
- convolution_desc_t *conv_desc, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
- const memory_desc_t *diff_bias_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
- diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
- nullptr, padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_convolution_backward_weights_desc_init(
- convolution_desc_t *conv_desc, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
- const memory_desc_t *diff_bias_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
- diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
- dilates, padding_l, padding_r, padding_kind);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
deleted file mode 100644
index 9604e0acf5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
+++ /dev/null
@@ -1,56 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "utils.hpp"
-
-#include "convolution_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-using namespace prop_kind;
-
-memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
- return desc->prop_kind == backward_data
- ? &desc->diff_src_desc : &desc->src_desc;
-}
-
-memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
- return desc->prop_kind == backward_weights
- ? &desc->diff_weights_desc : &desc->weights_desc;
-}
-
-memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
- return desc->prop_kind == backward_weights
- ? &desc->diff_bias_desc : &desc->bias_desc;
-}
-
-memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
- return utils::one_of(desc->prop_kind, forward_inference, forward_training)
- ? &desc->dst_desc : &desc->diff_dst_desc;
-}
-
-const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
-{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
-const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
-{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
-const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
-{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
-const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
-{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
-
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
deleted file mode 100644
index b10c36db49..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
+++ /dev/null
@@ -1,348 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CONVOLUTION_PD_HPP
-#define CONVOLUTION_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-status_t conv_desc_init(convolution_desc_t *conv_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t dilates,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind);
-
-memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
-memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
-memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
-memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
-const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
-const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
-const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
-const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
-
-struct convolution_fwd_pd_t;
-
-struct convolution_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::convolution;
-
- convolution_pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- {}
-
- const convolution_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case pkind_traits<base_pkind>::query_d:
- *(const convolution_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common conv aux functions */
-
- dim_t MB() const { return _src_md()->dims[0]; }
-
- dim_t IC() const { return _src_md()->dims[1]; }
- dim_t OC() const { return _dst_md()->dims[1]; }
- dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
-
- dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
- dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
- dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
-
- dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
- dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
- dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
-
- dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
- dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
- dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
-
- dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
- dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
- dim_t KSW() const { return desc_.strides[ndims() - 3]; }
-
- dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
- dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
- dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
-
- dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
- dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
- dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
- dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
- dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
- dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
-
- int ndims() const { return _src_md()->ndims; }
-
- bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
- bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
- bool has_zero_dim_memory() const {
- const auto s_d = memory_desc_wrapper(*_src_md());
- const auto d_d = memory_desc_wrapper(*_dst_md());
- return s_d.has_zero_dim() || d_d.has_zero_dim();
- }
-
-protected:
- convolution_desc_t desc_;
- const convolution_fwd_pd_t *hint_fwd_pd_;
-
- bool set_default_formats_common_template(
- memory_desc_t &src_md, format_tag_t src_tag,
- memory_desc_t &wei_md, format_tag_t wei_tag,
- memory_desc_t &dst_md, format_tag_t dst_tag,
- memory_desc_t &bia_md) {
- using namespace format_tag;
-
-# define IS_OK(f) \
- do { if ((f) != status::success) return false; } while(0)
- if (src_md.format_kind == format_kind::any
- && !utils::one_of(src_tag, any, undef))
- IS_OK(memory_desc_init_by_tag(src_md, src_tag));
- if (dst_md.format_kind == format_kind::any
- && !utils::one_of(dst_tag, any, undef))
- IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
- if (wei_md.format_kind == format_kind::any
- && !utils::one_of(wei_tag, any, undef))
- IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
- if (with_bias() && bia_md.format_kind == format_kind::any)
- IS_OK(memory_desc_init_by_tag(bia_md, x));
-# undef IS_OK
-
- return true;
- }
-
- bool set_default_alg_kind(alg_kind_t alg_kind) {
- assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
- alg_kind::convolution_winograd));
- if (desc_.alg_kind == alg_kind::convolution_auto)
- desc_.alg_kind = alg_kind;
- return desc_.alg_kind == alg_kind;
- }
-
- bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
- data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
- bool ok = true
- && (src_dt == data_type::undef || _src_md()->data_type == src_dt)
- && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
- && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
- && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
- if (with_bias() && bia_dt != data_type::undef)
- ok = ok && _bia_md()->data_type == bia_dt;
- return ok;
- }
-
-private:
- const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
- const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
- const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
- const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
-};
-
-struct convolution_fwd_pd_t: public convolution_pd_t {
- typedef convolution_fwd_pd_t base_class;
- typedef convolution_fwd_pd_t hint_class;
-
- convolution_fwd_pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , weights_md_(desc_.weights_desc)
- , bias_md_(desc_.bias_desc)
- , dst_md_(desc_.dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_BIAS && with_bias())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override {
- if (index == 0) return &weights_md_;
- if (index == 1 && with_bias()) return &bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2 + with_bias(); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t weights_md_;
- memory_desc_t bias_md_;
- memory_desc_t dst_md_;
-
- bool set_default_formats_common(format_tag_t src_tag,
- format_tag_t wei_tag, format_tag_t dst_tag) {
- return set_default_formats_common_template(src_md_, src_tag,
- weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
- }
-};
-
-struct convolution_bwd_data_pd_t: public convolution_pd_t {
- typedef convolution_bwd_data_pd_t base_class;
- typedef convolution_fwd_pd_t hint_class;
-
- convolution_bwd_data_pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_src_md_(desc_.diff_src_desc)
- , weights_md_(desc_.weights_desc)
- , bias_md_(desc_.bias_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override {
- if (index == 0) return &weights_md_;
- if (index == 1 && with_bias()) return &bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2 + with_bias(); }
- virtual int n_outputs() const override { return 1; }
-
- virtual bool support_bias() const { return false; }
-
-protected:
- memory_desc_t diff_src_md_;
- memory_desc_t weights_md_;
- memory_desc_t bias_md_;
- memory_desc_t diff_dst_md_;
-
- bool set_default_formats_common(format_tag_t diff_src_tag,
- format_tag_t wei_tag, format_tag_t diff_dst_tag) {
- return set_default_formats_common_template(diff_src_md_, diff_src_tag,
- weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
- }
-};
-
-struct convolution_bwd_weights_pd_t: public convolution_pd_t {
- typedef convolution_bwd_weights_pd_t base_class;
- typedef convolution_fwd_pd_t hint_class;
-
- convolution_bwd_weights_pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , diff_weights_md_(desc_.diff_weights_desc)
- , diff_bias_md_(desc_.diff_bias_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
- if (index == 0) return &diff_weights_md_;
- if (index == 1 && with_bias()) return &diff_bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1 + with_bias(); }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t diff_weights_md_;
- memory_desc_t diff_bias_md_;
- memory_desc_t diff_dst_md_;
-
- bool set_default_formats_common(format_tag_t src_tag,
- format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
- return set_default_formats_common_template(src_md_, src_tag,
- diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
- diff_bias_md_);
- }
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
deleted file mode 100644
index 98063c1c37..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
+++ /dev/null
@@ -1,188 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t dilates, const dims_t padding_l,
- const dims_t padding_r, padding_kind_t padding_kind) {
- bool args_ok = true
- && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
- padding_l)
- && one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
- && one_of(padding_kind, padding_kind::padding_zero);
- if (!args_ok)
- return invalid_arguments;
-
- if (padding_r == nullptr)
- padding_r = padding_l;
-
- auto dd = deconvolution_desc_t();
- dd.primitive_kind = primitive_kind::deconvolution;
- dd.prop_kind = prop_kind;
- dd.alg_kind = alg_kind;
-
- dd.diff_src_desc = dd.src_desc = zero_md();
- dd.diff_dst_desc = dd.dst_desc = zero_md();
- dd.diff_weights_desc = dd.weights_desc = zero_md();
- dd.diff_bias_desc = dd.bias_desc = zero_md();
-
- const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
- const bool with_bias
- = bias_desc && bias_desc->format_kind != format_kind::undef;
- const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
-
- (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
- (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
- (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
- = *weights_desc;
- if (with_bias)
- (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
- = *bias_desc;
-
- int sp_dims = src_desc->ndims - 2;
- utils::array_copy(dd.strides, strides, sp_dims);
- utils::array_copy(dd.padding[0], padding_l, sp_dims);
- utils::array_copy(dd.padding[1], padding_r, sp_dims);
- if (dilates)
- utils::array_copy(dd.dilates, dilates, sp_dims);
- else
- utils::array_set(dd.dilates, 0, sp_dims);
-
- dd.padding_kind = padding_kind;
- dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
- weights_desc->data_type, dst_desc->data_type, prop_kind);
-
- const int g = with_groups ? weights_desc->dims[0] : 1;
- bool consistency = true
- && src_desc->ndims == dst_desc->ndims
- && utils::one_of(src_desc->ndims, 3, 4, 5)
- && utils::one_of(weights_desc->ndims, src_desc->ndims,
- src_desc->ndims + 1)
- && (with_bias ? bias_desc->ndims == 1 : true)
- && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
- && src_desc->dims[0] == dst_desc->dims[0]
- && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
- && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
- for (int i = 2; i < src_desc->ndims; ++i) {
- int src = src_desc->dims[i];
- int ker = weights_desc->dims[with_groups + i];
- int dil = dd.dilates[i - 2];
- int pad = padding_l[i - 2] + padding_r[i - 2];
- int str = strides[i - 2];
- int dst = dst_desc->dims[i];
- int ker_range = 1 + (ker - 1) * (dil + 1);
-
- consistency
- = consistency && (dst - ker_range + pad) / str + 1 == src;
- }
- if (!consistency)
- return invalid_arguments;
-
- *deconv_desc = dd;
- return success;
-}
-}
-
-status_t mkldnn_deconvolution_forward_desc_init(
- deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
- alg_kind_t alg_kind, const memory_desc_t *src_desc,
- const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_desc, const dims_t strides,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
- weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
- padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_deconvolution_forward_desc_init(
- deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
- alg_kind_t alg_kind, const memory_desc_t *src_desc,
- const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_desc, const dims_t strides,
- const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
- weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
- padding_r, padding_kind);
-}
-
-status_t mkldnn_deconvolution_backward_data_desc_init(
- deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
- const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
- weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
- padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
- deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
- const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
- weights_desc, nullptr, diff_dst_desc, strides, dilates, padding_l,
- padding_r, padding_kind);
-}
-
-status_t mkldnn_deconvolution_backward_weights_desc_init(
- deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
- const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
- const dims_t strides, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
- diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
- padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
- deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
- const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
- const dims_t strides, const dims_t dilates, const dims_t padding_l,
- const dims_t padding_r, padding_kind_t padding_kind) {
- return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
- diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
- padding_l, padding_r, padding_kind);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
deleted file mode 100644
index 539e44bd9b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
+++ /dev/null
@@ -1,293 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef DECONVOLUTION_PD_HPP
-#define DECONVOLUTION_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "convolution_pd.hpp"
-#include "primitive_desc.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct deconvolution_fwd_pd_t;
-
-struct deconvolution_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::deconvolution;
-
- deconvolution_pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- {}
-
- const deconvolution_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case pkind_traits<base_pkind>::query_d:
- *(const deconvolution_desc_t **)result = desc();
- break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
-
- dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
-
- dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
- dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
- dim_t G() const
- { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
-
- dim_t ID() const {
- return ndims() >= 5
- ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
- }
- dim_t IH() const {
- return ndims() >= 4
- ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
- }
- dim_t IW() const {
- return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
- }
-
- dim_t OD() const {
- return ndims() >= 5
- ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
- }
- dim_t OH() const {
- return ndims() >= 4
- ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
- }
- dim_t OW() const {
- return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
- }
-
- dim_t KD() const {
- const int w_ndims = ndims() + with_groups();
- return ndims() >= 5
- ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
- }
- dim_t KH() const {
- const int w_ndims = ndims() + with_groups();
- return ndims() >= 4
- ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
- }
- dim_t KW() const {
- const int w_ndims = ndims() + with_groups();
- return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
- }
-
- dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
- dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
- dim_t KSW() const { return desc_.strides[ndims() - 3]; }
-
- dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
- dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
- dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
-
- dim_t padFront() const
- { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
- dim_t padBack() const
- { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
- dim_t padT() const
- { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
- dim_t padB() const
- { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
- dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
- dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
-
- bool with_bias() const {
- return
- !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
- }
-
- bool with_groups() const
- { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
-
- int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
- bool has_zero_dim_memory() const {
- const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
- const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
- return s_d.has_zero_dim() || d_d.has_zero_dim();
- }
-
-protected:
- deconvolution_desc_t desc_;
- const deconvolution_fwd_pd_t *hint_fwd_pd_;
-};
-
-struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
- typedef deconvolution_fwd_pd_t base_class;
- typedef deconvolution_fwd_pd_t hint_class;
-
- deconvolution_fwd_pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , weights_md_(desc_.weights_desc)
- , bias_md_(desc_.bias_desc)
- , dst_md_(desc_.dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_BIAS && with_bias())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override {
- if (index == 0) return &weights_md_;
- if (index == 1 && with_bias()) return &bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2 + with_bias(); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t weights_md_;
- memory_desc_t bias_md_;
- memory_desc_t dst_md_;
-};
-
-struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
- typedef deconvolution_bwd_data_pd_t base_class;
- typedef deconvolution_fwd_pd_t hint_class;
-
- deconvolution_bwd_data_pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_src_md_(desc_.diff_src_desc)
- , weights_md_(desc_.weights_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override
- { return index == 0 ? &weights_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t diff_src_md_;
- memory_desc_t weights_md_;
- memory_desc_t diff_dst_md_;
-};
-
-struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
- typedef deconvolution_bwd_weights_pd_t base_class;
- typedef deconvolution_fwd_pd_t hint_class;
-
- deconvolution_bwd_weights_pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , diff_weights_md_(desc_.diff_weights_desc)
- , diff_bias_md_(desc_.diff_bias_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
- if (index == 0) return &diff_weights_md_;
- if (index == 1 && with_bias()) return &diff_bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1 + with_bias(); }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t diff_weights_md_;
- memory_desc_t diff_bias_md_;
- memory_desc_t diff_dst_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
deleted file mode 100644
index f1708fca52..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
+++ /dev/null
@@ -1,84 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
- alg_kind_t alg_kind, const memory_desc_t *data_desc,
- const memory_desc_t *diff_data_desc, float alpha, float beta) {
- bool args_ok = true
- && !any_null(eltwise_desc, data_desc)
- && one_of(prop_kind, forward_training, forward_inference,
- backward_data)
- && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
- && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
- if (!args_ok) return invalid_arguments;
-
- auto ed = eltwise_desc_t();
- ed.primitive_kind = primitive_kind::eltwise;
- ed.prop_kind = prop_kind;
- ed.alg_kind = alg_kind;
-
- ed.data_desc = *data_desc;
- ed.diff_data_desc =
- (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
-
- ed.alpha = alpha;
- ed.beta = beta;
-
- bool consistency = true
- && IMPLICATION(ed.prop_kind == backward_data,
- array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
- ed.diff_data_desc.ndims));
- if (!consistency) return invalid_arguments;
-
- *eltwise_desc = ed;
- return success;
-}
-}
-
-status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *data_desc, float alpha, float beta) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
- nullptr, alpha, beta);
-}
-
-status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
- alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
- const memory_desc_t *data_desc, float alpha, float beta) {
- return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
- diff_data_desc, alpha, beta);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
deleted file mode 100644
index 9fd260fcee..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
+++ /dev/null
@@ -1,161 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef ELTWISE_PD_HPP
-#define ELTWISE_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct eltwise_fwd_pd_t;
-
-struct eltwise_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::eltwise;
-
- eltwise_pd_t(mkldnn::impl::engine_t *engine,
- const eltwise_desc_t *adesc,
- const primitive_attr_t *attr,
- const eltwise_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , data_md_(desc_.data_desc)
- {}
-
- const eltwise_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::eltwise_d:
- *(const eltwise_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common eltwise aux functions */
-
- dim_t MB() const { return data_desc().dims[0]; }
- dim_t C() const { return data_desc().dims[1]; }
- dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
- dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
- dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
-
- int ndims() const { return data_desc().ndims; }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
- bool has_zero_dim_memory() const
- { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
-
-protected:
- eltwise_desc_t desc_;
- const eltwise_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t data_md_;
-
-private:
- const memory_desc_t &data_desc() const { return desc_.data_desc; }
-};
-
-struct eltwise_fwd_pd_t: public eltwise_pd_t {
- typedef eltwise_fwd_pd_t base_class;
- typedef eltwise_fwd_pd_t hint_class;
-
- eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
- const eltwise_desc_t *adesc,
- const primitive_attr_t *attr,
- const eltwise_fwd_pd_t *hint_fwd_pd)
- : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override { return 1; }
-
- bool is_zero_preserved() const
- { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
-};
-
-struct eltwise_bwd_pd_t: public eltwise_pd_t {
- typedef eltwise_bwd_pd_t base_class;
- typedef eltwise_fwd_pd_t hint_class;
-
- eltwise_bwd_pd_t(engine_t *engine,
- const eltwise_desc_t *adesc,
- const primitive_attr_t *attr,
- const eltwise_fwd_pd_t *hint_fwd_pd)
- : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_data_md_(desc_.diff_data_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1; }
-
- bool is_zero_preserved() const { return true; }
-
-protected:
- memory_desc_t diff_data_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
deleted file mode 100644
index 3b3e25456d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-#include "engine.hpp"
-#include "nstl.hpp"
-
-#include "c_types_map.hpp"
-#include "../cpu/cpu_engine.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-engine_factory_t *engine_factories[] = {
- &cpu::engine_factory,
- nullptr,
-};
-
-static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
- for (engine_factory_t **ef = engine_factories; *ef; ef++)
- if ((*ef)->kind() == kind)
- return *ef;
- return nullptr;
-}
-
-}
-}
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-
-size_t mkldnn_engine_get_count(engine_kind_t kind) {
- engine_factory_t *ef = get_engine_factory(kind);
- return ef != nullptr ? ef->count() : 0;
-}
-
-status_t mkldnn_engine_create(engine_t **engine,
- engine_kind_t kind, size_t index) {
- if (engine == nullptr)
- return invalid_arguments;
-
- engine_factory_t *ef = get_engine_factory(kind);
- if (ef == nullptr || index >= ef->count())
- return invalid_arguments;
-
- return ef->engine_create(engine, index);
-}
-
-status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
- if (engine == nullptr)
- return invalid_arguments;
- *kind = engine->kind();
- return success;
-}
-
-status_t mkldnn_engine_destroy(engine_t *engine) {
- /* TODO: engine->dec_ref_count(); */
- delete engine;
- return success;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
deleted file mode 100644
index 8ac8a29de5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
+++ /dev/null
@@ -1,119 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef ENGINE_HPP
-#define ENGINE_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive.hpp"
-#include "utils.hpp"
-
-/** \brief An abstraction of an execution unit with shared resources
- *
- * Responsibilities:
- * - Provide engine specific memory allocation
- * - Provide engine specific primitive_desc_t creators
- */
-struct mkldnn_engine: public mkldnn::impl::c_compatible {
- mkldnn_engine(mkldnn::impl::engine_kind_t kind)
- : kind_(kind)
- {}
- virtual ~mkldnn_engine() {}
-
- /** get kind of the current engine */
- virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
-
- /** allocate memory */
- virtual mkldnn::impl::status_t memory_create(
- mkldnn::impl::memory_t **memory,
- const mkldnn::impl::memory_desc_t *md,
- void *handle) = 0;
-
- /** implementation section (typedefs) */
-
- // TODO: remove engine?
- typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
- mkldnn::impl::reorder_pd_t **reorder_pd,
- mkldnn::impl::engine_t *engine,
- const mkldnn::impl::primitive_attr_t *attr,
- mkldnn::impl::engine_t *src_engine,
- const mkldnn::impl::memory_desc_t *src_md,
- mkldnn::impl::engine_t *dst_engine,
- const mkldnn::impl::memory_desc_t *dst_md);
-
- typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
- mkldnn::impl::concat_pd_t **concat_pd,
- mkldnn::impl::engine_t *engine,
- const mkldnn::impl::primitive_attr_t *attr,
- const mkldnn::impl::memory_desc_t *dst_md,
- int n, int concat_dim,
- const mkldnn::impl::memory_desc_t *src_mds);
-
- typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
- mkldnn::impl::sum_pd_t **sum_pd,
- mkldnn::impl::engine_t *engine,
- const mkldnn::impl::primitive_attr_t *attr,
- const mkldnn::impl::memory_desc_t *dst_md,
- int n, const float *scales,
- const mkldnn::impl::memory_desc_t *src_mds);
-
- typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
- mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
- const mkldnn::impl::primitive_attr_t *attr,
- mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
-
- /* implementation section */
-
- /** return the list of reorder implementations. engine guarantees to return
- * a NULL-terminated list */
- virtual const reorder_primitive_desc_create_f*
- get_reorder_implementation_list() const = 0;
-
- /** return the list of concat implementations. engine guarantees to return
- * a NULL-terminated list */
- virtual const concat_primitive_desc_create_f*
- get_concat_implementation_list() const = 0;
-
- /** return the list of sum implementations. engine guarantees to return
- * a NULL-terminated list */
- virtual const sum_primitive_desc_create_f*
- get_sum_implementation_list() const = 0;
-
- /** return the list of implementations. engine guarantees to return a
- * NULL-terminated list */
- virtual const primitive_desc_create_f* get_implementation_list() const = 0;
-
-protected:
- mkldnn::impl::engine_kind_t kind_;
-};
-
-namespace mkldnn {
-namespace impl {
-
-struct engine_factory_t: public c_compatible {
- virtual size_t count() const = 0;
- virtual engine_kind_t kind() const = 0;
- virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
deleted file mode 100644
index 5a9f58cb1e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
+++ /dev/null
@@ -1,106 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
- const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
- const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
- bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
- if (!args_ok) return invalid_arguments;
-
- auto id = inner_product_desc_t();
- id.primitive_kind = primitive_kind::inner_product;
- id.prop_kind = prop_kind;
-
- id.diff_src_desc = id.src_desc = zero_md();
- id.diff_dst_desc = id.dst_desc = zero_md();
- id.diff_weights_desc = id.weights_desc = zero_md();
- id.diff_bias_desc = id.bias_desc = zero_md();
-
- const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
- const bool with_bias =
- bias_desc && bias_desc->format_kind != format_kind::undef;
-
- (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
- (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
- (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
- *weights_desc;
- if (with_bias)
- (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
- *bias_desc;
-
- id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
- weights_desc->data_type, dst_desc->data_type, prop_kind);
-
- bool consistency = true
- && memory_desc_wrapper(weights_desc).nelems()
- && one_of(src_desc->ndims, 2, 3, 4, 5)
- && dst_desc->ndims == 2
- && weights_desc->ndims == src_desc->ndims
- && (with_bias ? bias_desc->ndims == 1 : true)
- && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
- && src_desc->dims[0] == dst_desc->dims[0]
- && array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
- src_desc->ndims - 1)
- && dst_desc->dims[1] == weights_desc->dims[0];
- if (!consistency) return invalid_arguments;
-
- *ip_desc = id;
- return success;
-}
-}
-
-status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
- prop_kind_t prop_kind, const memory_desc_t *src_desc,
- const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_desc) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
- dst_desc);
-}
-
-status_t mkldnn_inner_product_backward_data_desc_init(
- inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
- const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
-{
- return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
- nullptr, diff_dst_desc);
-}
-
-status_t mkldnn_inner_product_backward_weights_desc_init(
- inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
- const memory_desc_t *diff_weights_desc,
- const memory_desc_t *diff_bias_desc,
- const memory_desc_t *diff_dst_desc) {
- return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
- diff_bias_desc, diff_dst_desc);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
deleted file mode 100644
index 091cf0f5d6..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
+++ /dev/null
@@ -1,56 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "utils.hpp"
-
-#include "inner_product_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-using namespace prop_kind;
-
-memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
- return desc->prop_kind == backward_data
- ? &desc->diff_src_desc : &desc->src_desc;
-}
-
-memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
- return desc->prop_kind == backward_weights
- ? &desc->diff_weights_desc : &desc->weights_desc;
-}
-
-memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
- return desc->prop_kind == backward_weights
- ? &desc->diff_bias_desc : &desc->bias_desc;
-}
-
-memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
- return utils::one_of(desc->prop_kind, forward_inference, forward_training)
- ? &desc->dst_desc : &desc->diff_dst_desc;
-}
-
-const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
-{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
-const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
-{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
-const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
-{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
-const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
-{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
-
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
deleted file mode 100644
index c426de632c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
+++ /dev/null
@@ -1,321 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef INNER_PRODUCT_PD_HPP
-#define INNER_PRODUCT_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
-memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
-memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
-memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
-const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
-const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
-const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
-const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
-
-struct inner_product_fwd_pd_t;
-
-struct inner_product_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::inner_product;
-
- inner_product_pd_t(engine_t *engine,
- const inner_product_desc_t *adesc,
- const primitive_attr_t *attr,
- const inner_product_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- {}
-
- const inner_product_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::inner_product_d:
- *(const inner_product_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common inner_product aux functions */
-
- dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
- dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
- dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
-
- dim_t ID() const {
- return ndims() >= 5
- ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
- }
- dim_t IH() const {
- return ndims() >= 4
- ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
- }
- dim_t IW() const {
- return ndims() >= 3
- ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
- }
-
- dim_t OD() const {
- return ndims() >= 5
- ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
- }
- dim_t OH() const {
- return ndims() >= 4
- ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
- }
- dim_t OW() const {
- return ndims() >= 3
- ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
- }
-
- dim_t KD() const {
- return ndims() >= 5
- ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
- }
- dim_t KH() const {
- return ndims() >= 4
- ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
- }
- dim_t KW() const {
- return ndims() >= 3
- ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
- }
-
- dim_t IC_total() const {
- return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
- ndims() - 1);
- }
-
- dim_t IC_total_padded() const {
- auto src_d = desc()->prop_kind == prop_kind::backward_data
- ? memory_desc_wrapper(diff_src_md())
- : memory_desc_wrapper(src_md());
- assert(src_d.is_blocking_desc());
- if (!src_d.is_blocking_desc()) return -1;
- return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
- }
-
- int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
-
- bool with_bias() const
- { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
-
- bool has_zero_dim_memory() const {
- const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
- const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
- return s_d.has_zero_dim() || d_d.has_zero_dim();
- }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
-protected:
- inner_product_desc_t desc_;
- const inner_product_fwd_pd_t *hint_fwd_pd_;
-
- status_t template_set_default_params(memory_desc_t &src_md,
- memory_desc_t &weights_md, memory_desc_t &dst_md,
- memory_desc_t *bias_md) {
- using namespace format_tag;
- if (src_md.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md,
- utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
- }
- if (dst_md.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_md, nc));
- if (weights_md.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(weights_md,
- utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
- }
- if (bias_md && bias_md->format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(*bias_md, x));
- return status::success;
- }
-};
-
-struct inner_product_fwd_pd_t: public inner_product_pd_t {
- typedef inner_product_fwd_pd_t base_class;
- typedef inner_product_fwd_pd_t hint_class;
-
- inner_product_fwd_pd_t(engine_t *engine,
- const inner_product_desc_t *adesc,
- const primitive_attr_t *attr,
- const inner_product_fwd_pd_t *hint_fwd_pd)
- : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , weights_md_(desc_.weights_desc)
- , bias_md_(desc_.bias_desc)
- , dst_md_(desc_.dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_BIAS && with_bias())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override {
- if (index == 0) return &weights_md_;
- if (index == 1 && with_bias()) return &bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2 + with_bias(); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t weights_md_;
- memory_desc_t bias_md_;
- memory_desc_t dst_md_;
-
- status_t set_default_params() {
- return template_set_default_params(src_md_, weights_md_, dst_md_,
- &bias_md_);
- }
-};
-
-struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
- typedef inner_product_bwd_data_pd_t base_class;
- typedef inner_product_fwd_pd_t hint_class;
-
- inner_product_bwd_data_pd_t(engine_t *engine,
- const inner_product_desc_t *adesc,
- const primitive_attr_t *attr,
- const inner_product_fwd_pd_t *hint_fwd_pd)
- : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_src_md_(desc_.diff_src_desc)
- , weights_md_(desc_.weights_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *weights_md(int index = 0) const override
- { return index == 0 ? &weights_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t diff_src_md_;
- memory_desc_t weights_md_;
- memory_desc_t diff_dst_md_;
-
- status_t set_default_params() {
- return template_set_default_params(diff_src_md_, weights_md_,
- diff_dst_md_, nullptr);
- }
-};
-
-struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
- typedef inner_product_bwd_weights_pd_t base_class;
- typedef inner_product_fwd_pd_t hint_class;
-
- inner_product_bwd_weights_pd_t(engine_t *engine,
- const inner_product_desc_t *adesc,
- const primitive_attr_t *attr,
- const inner_product_fwd_pd_t *hint_fwd_pd)
- : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , diff_weights_md_(desc_.diff_weights_desc)
- , diff_bias_md_(desc_.diff_bias_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
- if (index == 0) return &diff_weights_md_;
- if (index == 1 && with_bias()) return &diff_bias_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override { return 2; }
- virtual int n_outputs() const override { return 1 + with_bias(); }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t diff_weights_md_;
- memory_desc_t diff_bias_md_;
- memory_desc_t diff_dst_md_;
-
- status_t set_default_params() {
- return template_set_default_params(src_md_, diff_weights_md_,
- diff_dst_md_, &diff_bias_md_);
- }
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
deleted file mode 100644
index fcf18b556f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
+++ /dev/null
@@ -1,91 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t lrn_desc_init(lrn_desc_t *lrn_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
- dim_t local_size, float alpha, float beta, float k) {
- bool args_ok = true
- && !any_null(lrn_desc, data_desc)
- && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
- && one_of(prop_kind, forward_training, forward_inference, backward_data)
- && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
- if (!args_ok) return invalid_arguments;
-
- auto ld = lrn_desc_t();
- ld.primitive_kind = primitive_kind::lrn;
- ld.prop_kind = prop_kind;
- ld.alg_kind = alg_kind;
-
- const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
-
- ld.data_desc = *data_desc;
- if (!is_fwd)
- ld.diff_data_desc = *diff_data_desc;
- else
- ld.diff_data_desc = zero_md();
- ld.local_size = local_size;
- ld.lrn_alpha = alpha;
- ld.lrn_beta = beta;
- ld.lrn_k = k;
-
- bool consistency = true
- && ld.data_desc.ndims == 4;
- if (ld.prop_kind == backward_data)
- consistency = consistency
- && ld.diff_data_desc.ndims == 4
- && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
- if (!consistency) return invalid_arguments;
-
- *lrn_desc = ld;
- return success;
-}
-}
-
-status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *data_desc, dim_t local_size, float alpha,
- float beta, float k) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
- local_size, alpha, beta, k);
-}
-
-status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
- alg_kind_t alg_kind, const memory_desc_t *data_desc,
- const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
- float beta, float k) {
- return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
- diff_data_desc, local_size, alpha, beta, k);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
deleted file mode 100644
index 90886e9656..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
+++ /dev/null
@@ -1,170 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef LRN_PD_HPP
-#define LRN_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct lrn_fwd_pd_t;
-
-struct lrn_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::lrn;
-
- lrn_pd_t(engine_t *engine,
- const lrn_desc_t *adesc,
- const primitive_attr_t *attr,
- const lrn_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , data_md_(desc_.data_desc)
- , ws_md_()
- {}
-
- const lrn_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::lrn_d:
- *(const lrn_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common lrn aux functions */
-
- dim_t MB() const { return data_desc().dims[0]; }
- dim_t C() const { return data_desc().dims[1]; }
- dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
- dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
- dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
-
- int ndims() const { return data_desc().ndims; }
-
- bool has_zero_dim_memory() const
- { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
-protected:
- lrn_desc_t desc_;
- const lrn_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t data_md_;
- memory_desc_t ws_md_;
-
-private:
- const memory_desc_t &data_desc() const { return desc_.data_desc; }
-};
-
-struct lrn_fwd_pd_t: public lrn_pd_t {
- typedef lrn_fwd_pd_t base_class;
- typedef lrn_fwd_pd_t hint_class;
-
- lrn_fwd_pd_t(engine_t *engine,
- const lrn_desc_t *adesc,
- const primitive_attr_t *attr,
- const lrn_fwd_pd_t *hint_fwd_pd)
- : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override
- { return 1 + (workspace_md() != nullptr); }
-};
-
-struct lrn_bwd_pd_t: public lrn_pd_t {
- typedef lrn_bwd_pd_t base_class;
- typedef lrn_fwd_pd_t hint_class;
-
- lrn_bwd_pd_t(engine_t *engine,
- const lrn_desc_t *adesc,
- const primitive_attr_t *attr,
- const lrn_fwd_pd_t *hint_fwd_pd)
- : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_data_md_(desc_.diff_data_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::input;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
-
- virtual int n_inputs() const override
- { return 2 + (workspace_md() != nullptr); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t diff_data_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
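The D()/H()/W() helpers in lrn_pd_t above map the trailing dimensions of the data descriptor to depth/height/width and fall back to 1 when the tensor has no such dimension. A minimal standalone sketch of that convention (not part of the deleted sources; no mkl-dnn dependencies):

``` c++
#include <cstdio>
#include <vector>

// Mirror of the D()/H()/W() accessors above, written against a plain
// vector of dims instead of a memory descriptor. back_index: 3 -> D,
// 2 -> H, 1 -> W; dims[0] and dims[1] are MB and C and never spatial.
static long spatial(const std::vector<long> &dims, int back_index) {
    const int ndims = static_cast<int>(dims.size());
    const int d = ndims - back_index;
    return (d >= 2) ? dims[d] : 1;  // missing spatial dims default to 1
}

int main() {
    const std::vector<long> nchw  = {8, 64, 28, 28};      // 4D: no depth
    const std::vector<long> ncdhw = {8, 64, 16, 28, 28};  // 5D: with depth
    std::printf("4D D=%ld H=%ld W=%ld\n",
            spatial(nchw, 3), spatial(nchw, 2), spatial(nchw, 1));
    std::printf("5D D=%ld H=%ld W=%ld\n",
            spatial(ncdhw, 3), spatial(ncdhw, 2), spatial(ncdhw, 1));
}
```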
diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
deleted file mode 100644
index 3fddc0bd45..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
+++ /dev/null
@@ -1,280 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MATH_UTILS_HPP
-#define MATH_UTILS_HPP
-
-#include <stdint.h>
-#include <math.h>
-
-#include "utils.hpp"
-#include "nstl.hpp"
-#include "mkldnn_traits.hpp"
-
-#if defined(MKLDNN_X86_64)
-#include "immintrin.h"
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace math {
-
-/** rounds @p f to an integer according to the mxcsr register */
-inline int mxcsr_round(float f) {
-#if defined(MKLDNN_X86_64)
- return _mm_cvtss_si32(_mm_load_ss(&f));
-#else
- return (int)nearbyintf(f); // optimism
-#endif
-}
-
-template <typename data_t, typename acc_t>
-inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
- typename utils::remove_reference<data_t>::type>::type
-saturate(const acc_t &x) {
- return (typename utils::remove_reference<data_t>::type)x;
-}
-
-template <typename data_t, typename acc_t>
-inline typename utils::enable_if<nstl::is_integral<data_t>::value,
- typename utils::remove_reference<data_t>::type>::type
-saturate(const acc_t &x) {
- acc_t v = x;
- if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
- v = (acc_t)nstl::numeric_limits<data_t>::lowest();
- if (v > (acc_t)nstl::numeric_limits<data_t>::max())
- v = (acc_t)nstl::numeric_limits<data_t>::max();
- return (typename utils::remove_reference<data_t>::type)v;
-}
-
-template <typename data_t>
-double saturate(const double &x) {
- double v = x;
- if (v < (double)nstl::numeric_limits<data_t>::lowest())
- v = (double)nstl::numeric_limits<data_t>::lowest();
- if (v > (double)nstl::numeric_limits<data_t>::max())
- v = (double)nstl::numeric_limits<data_t>::max();
- return v;
-}
-
-template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
- return x <= 127u ? x : 127;
-}
-
-template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
- return x >= 0 ? x : 0;
-}
-
-template <typename out_t>
-typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
-out_round(float v) { return (out_t)mxcsr_round(v); }
-
-template <typename out_t>
-typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
-out_round(double v) { return (out_t)mxcsr_round((float)v); }
-
-template <typename out_t>
-typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
-out_round(float v) { return v; }
-
-inline int gcd(int a, int b) {
- a = impl::nstl::abs(a);
- b = impl::nstl::abs(b);
- if (a < b) { int x = a; a = b; b = x; }
-
- if (b == 0) return a;
-
- int r;
- while ((r = a % b) != 0) { a = b; b = r; }
-
- return b;
-}
-
-template <typename T>
-inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
-
-/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
-inline int ilog2q(size_t v) {
- if (v == 0)
- return -1;
-
- int p = 0;
-# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
- CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
-# undef CP
- return p;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U one_m_square(T x) {
- return (U)(1 - x) * (1 + x);
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U x_m_square(T x) {
- return (U)(1 - x) * x;
-}
-
-/* activation */
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U relu_fwd(T s, A alpha) {
- return s > 0 ? s : (U)(s * alpha);
-}
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U relu_bwd(T dd, T s, A alpha) {
- return s > 0 ? dd : (U)(dd * alpha);
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U tanh_fwd(T s) {
- const float e = tanhf((float) s);
- return (U)e;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U tanh_bwd(T dd, T s) {
- const float e = tanh_fwd<float>((float) s);
- return (U)(dd * (1 - e) * (1 + e));
-}
-
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U elu_fwd(T s, A alpha) {
- return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
-}
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U elu_bwd(T dd, T s, A alpha) {
- return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U square_fwd(T s) {
- return s * s;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U square_bwd(T dd, T s) {
- return dd * 2 * s;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U abs_fwd(T s) {
- return s > 0 ? s : -s;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U abs_bwd(T dd, T s) {
- return s > 0 ? dd : s < 0 ? -dd : 0;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U sqrt_fwd(T s) {
- return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U sqrt_bwd(T dd, T s) {
- return s > 0
- ? (U)(dd / (2 * ::sqrtf((float)(s))))
- : 0;
-}
-
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U linear_fwd(T s, A alpha, A beta) {
- return (U)(alpha * s + beta);
-}
-
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U linear_bwd(T dd, T s, A alpha, A beta) {
- (void) s;
- (void) beta;
- return (U)(dd * alpha);
-}
-
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U bounded_relu_fwd(T s, A alpha) {
- s = s > 0 ? s : 0;
- return s > alpha ? (U)(alpha) : s;
-}
-
-template <typename T, typename A,
- typename U = typename utils::remove_reference<T>::type>
-inline U bounded_relu_bwd(T dd, T s, A alpha) {
- return dd * (0 < s && s < alpha ? 1 : 0);
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U soft_relu_fwd(T s) {
- float max_logf = 8.872284e+01; //::logf(FLT_MAX)
- return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U soft_relu_bwd(T dd, T s) {
- return (U)(dd / (1 + ::expf((float)(-s))));
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U logistic_fwd(T s) {
- U v = (U)(::expf((float) -s));
- return 1 / (1 + v);
-}
-
-template <typename T, typename U = typename utils::remove_reference<T>::type>
-inline U logistic_bwd(T dd, T s) {
- U v = logistic_fwd<T, U>(s);
- return dd * v * (1 - v);
-}
-
-inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
- using namespace alg_kind;
- using namespace utils;
- const bool preserves_zero = true
- && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
- && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
- return preserves_zero;
-}
-
-inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
-{
- if (!bias)
- return 0.0f;
-
-#define CASE(dt) \
- case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
-
- switch (data_type) {
- CASE(data_type::s8);
- CASE(data_type::u8);
- CASE(data_type::s32);
- CASE(data_type::f32);
- default: assert(!"unimplemented");
- }
- return 0; // never happens (should probably be a NaN)
-#undef CASE
-}
-
-}
-}
-}
-
-#endif
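The helpers above keep clamping to a target type's range (saturate) separate from rounding (out_round / mxcsr_round). A minimal standalone sketch of the combined clamp-and-round idea (not part of the deleted sources; it uses std::nearbyintf instead of the MXCSR intrinsic, and the composition order is for illustration only):

``` c++
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Clamp a float accumulator into the representable range of int8_t,
// then round to the nearest integer (nearbyintf honours the current
// rounding mode, matching the spirit of mxcsr_round above).
static std::int8_t saturate_round_s8(float acc) {
    const float lo = static_cast<float>(std::numeric_limits<std::int8_t>::lowest());
    const float hi = static_cast<float>(std::numeric_limits<std::int8_t>::max());
    const float clamped = std::min(std::max(acc, lo), hi);
    return static_cast<std::int8_t>(std::nearbyintf(clamped));
}

int main() {
    std::printf("%d\n", saturate_round_s8(300.7f));   // 127  (saturated)
    std::printf("%d\n", saturate_round_s8(-1.4f));    // -1
    std::printf("%d\n", saturate_round_s8(-200.0f));  // -128 (saturated)
}
```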
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
deleted file mode 100644
index cea849c96e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
+++ /dev/null
@@ -1,238 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <stddef.h>
-#include <stdint.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::data_type;
-
-namespace {
-bool memory_desc_sanity_check(int ndims, const dims_t dims,
- data_type_t data_type, format_kind_t format_kind) {
- if (ndims == 0) return true;
-
- bool ok = true
- && dims != nullptr
- && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
- && one_of(data_type, f32, s32, s8, u8)
- && format_kind != format_kind::undef;
- if (!ok) return false;
- for (int d = 0; d < ndims; ++d)
- if (dims[d] < 0) return false;
-
- return true;
-}
-
-bool memory_desc_sanity_check(const memory_desc_t *md) {
- if (md == nullptr) return false;
- return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
- format_kind::any);
-}
-}
-
-status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
- const dims_t dims, data_type_t data_type, format_tag_t tag) {
- if (any_null(memory_desc)) return invalid_arguments;
- if (ndims == 0 || tag == format_tag::undef) {
- *memory_desc = types::zero_md();
- return success;
- }
-
- format_kind_t format_kind = types::format_tag_to_kind(tag);
-
- /* memory_desc != 0 */
- bool args_ok = !any_null(memory_desc)
- && memory_desc_sanity_check(ndims, dims, data_type, format_kind);
- if (!args_ok) return invalid_arguments;
-
- auto md = memory_desc_t();
- md.ndims = ndims;
- array_copy(md.dims, dims, ndims);
- md.data_type = data_type;
- array_copy(md.padded_dims, dims, ndims);
- md.format_kind = format_kind;
-
- status_t status = success;
- if (tag == format_tag::undef) {
- status = invalid_arguments;
- } else if (tag == format_tag::any) {
- // nop
- } else if (format_kind == format_kind::blocked) {
- status = memory_desc_wrapper::compute_blocking(md, tag);
- } else {
- assert(!"unreachable");
- status = invalid_arguments;
- }
-
- if (status == success)
- *memory_desc = md;
-
- return status;
-}
-
-status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
- int ndims, const dims_t dims, data_type_t data_type,
- const dims_t strides) {
- if (any_null(memory_desc)) return invalid_arguments;
- if (ndims == 0) {
- *memory_desc = types::zero_md();
- return success;
- }
-
- /* memory_desc != 0 */
- bool args_ok = !any_null(memory_desc)
- && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
- if (!args_ok) return invalid_arguments;
-
- auto md = memory_desc_t();
- md.ndims = ndims;
- array_copy(md.dims, dims, ndims);
- md.data_type = data_type;
- array_copy(md.padded_dims, dims, ndims);
- md.format_kind = format_kind::blocked;
-
- dims_t default_strides = {0};
- if (strides == nullptr) {
- default_strides[md.ndims - 1] = 1;
- for (int d = md.ndims - 2; d >= 0; --d)
- default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
- strides = default_strides;
- } else {
- /* TODO: add sanity check for the provided strides */
- }
-
- array_copy(md.format_desc.blocking.strides, strides, md.ndims);
-
- *memory_desc = md;
-
- return status::success;
-}
-
-status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
- const memory_desc_t *parent_md, const dims_t dims,
- const dims_t offsets) {
- if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
- return invalid_arguments;
-
- const memory_desc_wrapper src_d(parent_md);
-
- for (int d = 0; d < src_d.ndims(); ++d) {
- if (dims[d] < 0 || offsets[d] < 0
- || (offsets[d] + dims[d] > src_d.dims()[d]))
- return invalid_arguments;
- }
-
- if (src_d.format_kind() != format_kind::blocked)
- return unimplemented;
-
- dims_t blocks;
- src_d.compute_blocks(blocks);
-
- memory_desc_t dst_d = *parent_md;
- auto &dst_d_blk = dst_d.format_desc.blocking;
-
- /* TODO: put this into memory_desc_wrapper */
- for (int d = 0; d < src_d.ndims(); ++d) {
- /* very limited functionality for now */
- const bool ok = true
- && offsets[d] % blocks[d] == 0 /* [r1] */
- && src_d.padded_offsets()[d] == 0
- && (false
- || dims[d] % blocks[d] == 0
- || dims[d] < blocks[d]);
- if (!ok)
- return unimplemented;
-
- const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
-
- dst_d.dims[d] = dims[d];
- dst_d.padded_dims[d] = is_right_border
- ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
- dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
- dst_d.offset0 += /* [r1] */
- offsets[d] / blocks[d] * dst_d_blk.strides[d];
- }
-
- *md = dst_d;
-
- return success;
-}
-
-int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
- const memory_desc_t *rhs) {
- if (lhs == rhs) return 1;
- if (any_null(lhs, rhs)) return 0;
- return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
-}
-
-size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
- if (md == nullptr) return 0;
- return memory_desc_wrapper(*md).size();
-}
-
-status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
- engine_t *engine, void *handle) {
- if (any_null(memory, engine)) return invalid_arguments;
- memory_desc_t z_md = types::zero_md();
- return engine->memory_create(memory, md ? md : &z_md, handle);
-}
-
-status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
- const memory_desc_t **md) {
- if (any_null(memory, md)) return invalid_arguments;
- *md = memory->md();
- return success;
-}
-
-status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
- if (any_null(memory, engine)) return invalid_arguments;
- *engine = memory->engine();
- return success;
-}
-
-status_t mkldnn_memory_get_data_handle(const memory_t *memory,
- void **handle) {
- if (any_null(handle))
- return invalid_arguments;
- if (memory == nullptr) {
- *handle = nullptr;
- return success;
- }
- return memory->get_data_handle(handle);
-}
-
-status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
- if (any_null(memory)) return invalid_arguments;
- return memory->set_data_handle(handle);
-}
-
-status_t mkldnn_memory_destroy(memory_t *memory) {
- delete memory;
- return success;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
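When mkldnn_memory_desc_init_by_strides() above is called with strides == nullptr, it fills in dense row-major strides: the last dimension gets stride 1 and each preceding stride is the following stride times the following padded dimension. A minimal standalone sketch of that rule (not part of the deleted sources):

``` c++
#include <cstdint>
#include <cstdio>
#include <vector>

// Dense, row-major default strides, as computed above when no explicit
// strides are provided (padding ignored here for brevity).
static std::vector<std::int64_t> dense_strides(const std::vector<std::int64_t> &dims) {
    std::vector<std::int64_t> strides(dims.size(), 1);
    for (int d = static_cast<int>(dims.size()) - 2; d >= 0; --d)
        strides[d] = strides[d + 1] * dims[d + 1];
    return strides;
}

int main() {
    // A 2x3x4x5 tensor laid out densely: prints 60 20 5 1.
    for (std::int64_t s : dense_strides({2, 3, 4, 5}))
        std::printf("%lld ", static_cast<long long>(s));
    std::printf("\n");
}
```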
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
deleted file mode 100644
index 03dfee01ff..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
+++ /dev/null
@@ -1,63 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MEMORY_HPP
-#define MEMORY_HPP
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-
-struct mkldnn_memory: public mkldnn::impl::c_compatible {
- mkldnn_memory(mkldnn::impl::engine_t *engine,
- const mkldnn::impl::memory_desc_t *md)
- : engine_(engine), md_(*md) {}
- virtual ~mkldnn_memory() {}
-
- /** allocates/initializes memory */
- virtual mkldnn::impl::status_t init() = 0;
-
- /** returns memory's engine */
- mkldnn::impl::engine_t *engine() const { return engine_; }
- /** returns memory's description */
- const mkldnn::impl::memory_desc_t *md() const { return &md_; }
-
- /** returns data handle */
- virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
-
- /** sets data handle */
- virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
-
- /** zeros padding */
- virtual mkldnn::impl::status_t zero_pad() const
- { return mkldnn::impl::status::success; }
-
-protected:
- mkldnn::impl::engine_t *engine_;
- const mkldnn::impl::memory_desc_t md_;
-
-private:
- mkldnn_memory() = delete;
- mkldnn_memory(const mkldnn_memory &) = delete;
- mkldnn_memory(mkldnn_memory &&) = delete;
- mkldnn_memory &operator=(const mkldnn_memory &) = delete;
- mkldnn_memory &operator=(mkldnn_memory &&) = delete;
-};
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
deleted file mode 100644
index 8a99be33f3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
+++ /dev/null
@@ -1,212 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include <initializer_list>
-
-#include "c_types_map.hpp"
-#include "memory_desc_wrapper.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-status_t fill_blocked(memory_desc_t &md,
- std::initializer_list<int> perm,
- std::initializer_list<int> inner_blks,
- std::initializer_list<int> inner_idxs) {
- const bool ok = true
- && perm.size() == (size_t)md.ndims
- && inner_blks.size() == inner_idxs.size();
- if (!ok) return status::invalid_arguments;
-
- md.offset0 = 0;
-
- blocking_desc_t &blk = md.format_desc.blocking;
-
- dim_t block_size = 1;
- dims_t blocks = {0};
- utils::array_set(blocks, 1, md.ndims);
-
- blk.inner_nblks = (int)inner_blks.size();
-
- int iblk = 0;
- for (const auto &b: inner_idxs)
- blk.inner_idxs[iblk++] = b;
-
- iblk = 0;
- for (const auto &b: inner_blks) {
- int dim = blk.inner_idxs[iblk];
- block_size *= b;
- blocks[dim] *= b;
- blk.inner_blks[iblk++] = b;
- }
-
- utils::array_set(md.padded_offsets, 0, md.ndims);
- for (int d = 0; d < md.ndims; ++d)
- md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
-
- dim_t stride = block_size;
- // if only we used C++14, the initializer_list would have rbegin()/rend()...
- for (int d = 0; d < md.ndims; ++d)
- stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
-
- for (const auto &d: perm) {
- if (md.padded_dims[d] == 0) {
- blk.strides[d] = 1;
- continue;
- }
- stride /= md.padded_dims[d] / blocks[d];
- blk.strides[d] = stride;
- }
-
- assert(stride == block_size);
-
- return status::success;
-}
-
-status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
- format_tag_t tag)
-{
- using namespace format_tag;
-
- if (memory_desc.ndims == 0) return status::invalid_arguments;
-
-# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
- case tag: return fill_blocked(memory_desc, __VA_ARGS__)
-
- switch (tag) {
- C(a, {0}, {}, {});
- C(ab, {0, 1}, {}, {});
- C(abc, {0, 1, 2}, {}, {});
- C(abcd, {0, 1, 2, 3}, {}, {});
- C(abcde, {0, 1, 2, 3, 4}, {}, {});
- C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
- C(abdec, {0, 1, 3, 4, 2}, {}, {});
- C(acb, {0, 2, 1}, {}, {});
- C(acbde, {0, 2, 1, 3, 4}, {}, {});
- C(acdb, {0, 2, 3, 1}, {}, {});
- C(acdeb, {0, 2, 3, 4, 1}, {}, {});
- C(ba, {1, 0}, {}, {});
- C(bac, {1, 0, 2}, {}, {});
- C(bacd, {1, 0, 2, 3}, {}, {});
- C(bcda, {1, 2, 3, 0}, {}, {});
- C(cba, {2, 1, 0}, {}, {});
- C(cdba, {2, 3, 1, 0}, {}, {});
- C(cdeba, {2, 3, 4, 1, 0}, {}, {});
- C(decab, {3, 4, 2, 0, 1}, {}, {});
-
- C(Abc4a, {0, 1, 2}, {4}, {0});
- C(aBc4b, {0, 1, 2}, {4}, {1});
- C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
- C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
- C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
- C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
- C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
- C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
- C(aBCd4c4b, {0, 1, 2, 3}, {4, 4}, {2, 1});
- C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
- C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
- C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
- C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
- C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
- C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
- C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
- C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
- C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
- C(Acb4a, {0, 2, 1}, {4}, {0});
- C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
- C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
-
- C(Abc16a, {0, 1, 2}, {16}, {0});
- C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
- C(aBc16b, {0, 1, 2}, {16}, {1});
- C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
- C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
- C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
- C(aBc8b, {0, 1, 2}, {8}, {1});
- C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
- C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
- C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
- C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
- C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
- C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
- C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
- C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
- C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
- C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
- C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
- C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
- C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
- C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
- C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
- C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
- C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
- C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
- C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
- C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
- C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
- C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
- C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
- C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
- C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
- C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
- C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
- C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
- C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
- C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
- C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
- C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
- C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
- C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
- C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
- C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
- C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
- C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
- C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
- C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
- C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
- C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
- C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
- C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
- C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
- C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
- C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
- C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
- C(Acb16a, {0, 2, 1}, {16}, {0});
- C(Acb8a, {0, 2, 1}, {8}, {0});
- C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
- C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
- C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
- C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
- C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
- C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
- C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
- C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
- default: break;
- }
-
-#undef C
-
- return status::invalid_arguments;
-}
-
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
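To make fill_blocked() above concrete, the following standalone sketch (not part of the deleted sources) replays its stride arithmetic for the aBcd16b tag with dims {2, 17, 4, 4}; the zero-sized-dimension special case is omitted:

``` c++
#include <cstdint>
#include <cstdio>

int main() {
    const int ndims = 4;
    const std::int64_t dims[ndims]   = {2, 17, 4, 4};
    const int          perm[ndims]   = {0, 1, 2, 3};   // plain abcd order outside the block
    const std::int64_t blocks[ndims] = {1, 16, 1, 1};  // one 16-wide inner block on dim 1

    std::int64_t padded[ndims], strides[ndims];
    for (int d = 0; d < ndims; ++d)  // rnd_up to the block size
        padded[d] = (dims[d] + blocks[d] - 1) / blocks[d] * blocks[d];

    std::int64_t stride = 16;        // block_size: product of inner blocks
    for (int d = 0; d < ndims; ++d)
        stride *= padded[d] / blocks[d];
    for (int i = 0; i < ndims; ++i) {
        const int d = perm[i];       // peel off one dim per step, as above
        stride /= padded[d] / blocks[d];
        strides[d] = stride;
    }

    // Prints: padded dims 2 32 4 4 | outer strides 512 256 64 16
    for (int d = 0; d < ndims; ++d) std::printf("%lld ", (long long)padded[d]);
    std::printf("| ");
    for (int d = 0; d < ndims; ++d) std::printf("%lld ", (long long)strides[d]);
    std::printf("\n");
}
```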
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
deleted file mode 100644
index 1758f9078a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
+++ /dev/null
@@ -1,400 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MEMORY_DESC_WRAPPER_HPP
-#define MEMORY_DESC_WRAPPER_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "type_helpers.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-/** thin wrapper class over \struct memory_desc_t which allows easy
- * manipulation of the underlying C structure, which is taken by reference */
-struct memory_desc_wrapper: public c_compatible {
- const memory_desc_t *md_;
-
- /** constructor which takes a reference to a constant underlying C memory
- * descriptor \param md */
- memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
- memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
-
- /* implementing attributes */
- int ndims() const { return md_->ndims; }
- const dims_t &dims() const { return md_->dims; }
- data_type_t data_type() const { return md_->data_type; }
-
- const dims_t &padded_dims() const { return md_->padded_dims; }
- const dims_t &padded_offsets() const { return md_->padded_offsets; }
- dim_t offset0() const { return md_->offset0; }
-
- format_kind_t format_kind() const { return md_->format_kind; }
-
- bool is_blocking_desc() const
- { return format_kind() == format_kind::blocked; }
- bool is_wino_desc() const
- { return format_kind() == format_kind::wino; }
- bool is_rnn_packed_desc() const
- { return format_kind() == format_kind::rnn_packed; }
-
- const blocking_desc_t &blocking_desc() const {
- assert(is_blocking_desc());
- return md_->format_desc.blocking;
- }
- const wino_desc_t &wino_desc() const {
- assert(is_wino_desc());
- return md_->format_desc.wino_desc;
- }
- const rnn_packed_desc_t &rnn_packed_desc() const {
- assert(is_rnn_packed_desc());
- return md_->format_desc.rnn_packed_desc;
- }
-
- const memory_extra_desc_t &extra() const { return md_->extra; }
-
- /* some useful function */
-
- /** returns the number of elements including padding if \param with_padding
- * is true, and the number of data elements otherwise */
- dim_t nelems(bool with_padding = false) const {
- if (is_zero()) return 0;
- return utils::array_product(
- with_padding ? padded_dims() : dims(), ndims());
- }
-
- /** returns true if memory descriptor is zero */
- bool is_zero() const { return ndims() == 0; }
-
- /** returns true if memory descriptor contains zero as one of its dim */
- bool has_zero_dim() const { return nelems() == 0; }
-
- /** return the size of data type (a shortcut) */
- size_t data_type_size() const
- { return types::data_type_size(data_type()); }
-
- /** return the size of data type of additional buffer */
- size_t additional_buffer_data_size() const {
- if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
- return sizeof(int32_t);
- return 0;
- }
-
- /** return true if memory format has additional buffer */
- bool is_additional_buffer() const {
- return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
- }
-
- /** returns the size of additional buffer */
- size_t additional_buffer_size() const {
- if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
- int cmask = extra().compensation_mask;
- assert(cmask == 1 || cmask == 3);
- dim_t prod = 1;
- for (int d = 0; d < ndims(); ++d)
- if (cmask & (1<<d)) prod *= padded_dims()[d];
- return prod * additional_buffer_data_size();
- }
-
- return 0;
- }
-
- /** returns the size required to store described memory
- * note: if offset0 != 0 returns 0 (need to specify the behavior) */
- size_t size() const {
- if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
- return 0;
-
- if (format_kind() == format_kind::wino) {
- return wino_desc().size;
- } else if (format_kind() == format_kind::rnn_packed) {
- return rnn_packed_desc().size;
- } else {
- if (offset0() != 0) return 0;
-
- dims_t blocks = {0};
- compute_blocks(blocks);
-
- const auto &bd = blocking_desc();
-
- size_t max_size = 0;
- for (int d = 0; d < ndims(); ++d)
- max_size = nstl::max<size_t>(max_size,
- padded_dims()[d] / blocks[d] * bd.strides[d]);
-
- if (max_size == 1 && bd.inner_nblks != 0) {
- max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
- }
-
- return max_size * data_type_size() + additional_buffer_size();
- }
- }
-
- /** returns true if data is dense in memory */
- bool is_dense(bool with_padding = false) const {
- if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
- return false;
- return nelems(with_padding) * data_type_size() == size();
- }
-
- /** returns true if memory desc is fully defined */
- bool is_defined() const { return format_kind() != format_kind::any; }
-
- /** returns true if the only (potentially) padded dim is \param dim */
- bool only_padded_dim(int dim) const {
- for (int d = 0; d < ndims(); ++d)
- if (d != dim && dims()[d] != padded_dims()[d])
- return false;
- return true;
- }
-
- /** returns true if memory desc has blocked layout and block dims are 1s */
- bool is_plain() const {
- if (!is_blocking_desc()) return false;
- return blocking_desc().inner_nblks == 0;
- }
-
- /** returns overall block sizes */
- void compute_blocks(dims_t blocks) const {
- if (!is_blocking_desc()) {
- utils::array_set(blocks, 0, ndims());
- return;
- }
-
- utils::array_set(blocks, 1, ndims());
-
- const auto &bd = blocking_desc();
- for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
- blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
- }
-
- /* comparison section */
-
- bool operator==(const memory_desc_wrapper &rhs) const
- { return *this->md_ == *rhs.md_; }
- bool operator!=(const memory_desc_wrapper &rhs) const
- { return !operator==(rhs); }
- bool operator==(const memory_desc_t &rhs) const
- { return operator==(memory_desc_wrapper(rhs)); }
- bool operator!=(const memory_desc_t &rhs) const
- { return !operator==(rhs); }
-
- /** returns true if data (w/o padding if with_padding == false and w/
- * padding otherwise) have the same physical structure, i.e. dimensions,
- * strides, and blocked structure. Depending on with_data_type flag
- * data_type is taken or not taken into account. dim_start allows checking
- * similarity for the logical part of data [dim_start .. ndims()].
- * CAUTION: format kinds any and undef are not similar to anything, hence the
- * following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
- /* TODO: revise */
- bool similar_to(const memory_desc_wrapper &rhs,
- bool with_padding = true, bool with_data_type = true,
- int dim_start = 0) const;
-
- /** returns true if one memory can be reordered to another */
- bool consistent_with(const memory_desc_wrapper &rhs) const;
-
- /** returns true if the memory desc corresponds to the given format tag and
- * strides.
- * @sa memory_desc_matches_tag */
- bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
- return memory_desc_matches_tag(*md_, tag, strides);
- }
-
- /** returns matching tag (or undef if match is not found)
- * XXX: This is a workaround that eventually should go away! */
- template <typename... Tags>
- format_tag_t matches_one_of_tag(Tags ...tags) const {
- for (const auto tag: {tags...}) {
- if (memory_desc_matches_tag(*md_, tag))
- return tag;
- }
- return format_tag::undef;
- }
-
- /* offset section */
-
- /** returns physical offset by logical one. logical offset is represented by
- * an array \param pos. if \param is_pos_padded is true \param pos
- * represents the position in already padded area */
- dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
- assert(is_blocking_desc());
- const blocking_desc_t &blk = blocking_desc();
-
- dims_t pos_copy = {0};
- for (int d = 0; d < ndims(); ++d)
- pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
-
- dim_t phys_offset = offset0();
-
- if (blk.inner_nblks > 0) {
- dim_t blk_stride = 1;
- for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
- const int d = blk.inner_idxs[iblk];
- const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
-
- phys_offset += p * blk_stride;
-
- pos_copy[d] /= blk.inner_blks[iblk];
-
- blk_stride *= blk.inner_blks[iblk];
- }
- }
-
- for (int d = 0; d < ndims(); ++d) {
- const dim_t p = pos_copy[d];
- phys_offset += p * blk.strides[d];
- }
-
- return phys_offset;
- }
-
- /** returns physical offset by logical one. logical offset is represented by
- * a scalar \param l_offset. if \param is_pos_padded is true, \param
- * l_offset represents logical offset in already padded area */
- dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
- assert(is_blocking_desc());
- dims_t pos;
- for (int rd = 0; rd < ndims(); ++rd) {
- const int d = ndims() - 1 - rd;
- const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
- pos[d] = l_offset % cur_dim;
- l_offset /= cur_dim;
- }
- return off_v(pos, is_pos_padded);
- }
-
- /** returns physical offset by logical one. logical offset is represented by
- * a tuple of indices (\param xn, ..., \param x1, \param x0) */
- template<typename... Args>
- dim_t off(Args... args) const {
- assert(sizeof...(args) == ndims());
- dims_t pos = { args... };
- return off_v(pos, false);
- }
-
- /** returns physical offset by logical one. logical offset is represented by
- * a tuple of indices (\param xn, ..., \param x1, \param x0) in already
- * padded area */
- template<typename... Args>
- dim_t off_padding(Args... args) const {
- assert(sizeof...(args) == ndims());
- dims_t pos = { args... };
- return off_v(pos, true);
- }
-
- /** returns physical offset by logical one. Logical offset is represented by
- * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
- * user responsibility to adjust the result to get offset within blocks */
- template<typename ...Args>
- dim_t blk_off(Args... args) const {
- return _blk_off<sizeof...(args), Args...>(args...);
- }
-
- template<bool skip_first, typename T, typename ...Args>
- dim_t blk_off(T xn, Args... args) const {
- return skip_first
- ? blk_off<Args...>(args...)
- : blk_off<T, Args...>(xn, args...);
- }
-
- /* static functions section */
- /* TODO: replace with non-static, once md_ becomes non-const ref */
-
- static status_t compute_blocking(memory_desc_t &memory_desc,
- format_tag_t tag);
-
-private:
- /* TODO: put logical_offset in utils */
- template<typename T>
- dim_t logical_offset(T x0) const { return x0; }
-
- template<typename T, typename... Args>
- dim_t logical_offset(T xn, Args... args) const {
- const size_t n_args = sizeof...(args);
- return xn * utils::array_product<n_args>(
- &dims()[ndims() - n_args]) + logical_offset(args...);
- }
-
- template<int ORIG_LEN, typename ...Void>
- dim_t _blk_off() const { return offset0(); }
-
- template<int ORIG_LEN, typename T, typename ...Args>
- dim_t _blk_off(T xc, Args ...args) const {
- assert(is_blocking_desc());
- constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
- return xc * blocking_desc().strides[dc]
- + _blk_off<ORIG_LEN, Args...>(args...);
- }
-};
-
-inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
- bool with_padding, bool with_data_type, int dim_start) const {
- using namespace utils;
-
- if (one_of(format_kind(), format_kind::undef, format_kind::any))
- return false;
- if (is_wino_desc() || is_rnn_packed_desc())
- return false;
-
- const int ds = dim_start;
- const auto &blk = blocking_desc();
- const auto &r_blk = rhs.blocking_desc();
-
- return ndims() == rhs.ndims()
- && dim_start <= ndims() /* guard */
- && format_kind() == rhs.format_kind()
- && IMPLICATION(with_data_type, data_type() == rhs.data_type())
- && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
- && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
- && blk.inner_nblks == r_blk.inner_nblks
- && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
- && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
- && IMPLICATION(with_padding, true
- && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
- ndims() - ds)
- && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
- ndims() - ds));
-}
-
-inline bool memory_desc_wrapper::consistent_with(
- const memory_desc_wrapper &rhs) const {
- if (ndims() == rhs.ndims()) {
- for (int d = 0; d < ndims(); ++d) {
- if (dims()[d] != rhs.dims()[d]) return false;
- }
- return true;
- } else {
- /* TODO: revise.
- * is the following possible?
- * [1, a, b] <--reorder--> [a, b]
- * [a, 1, b] <--reorder--> [a, b]
- * not, at least for now */
- return false;
- }
-}
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
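A minimal standalone replay (not part of the deleted sources) of the off_v() computation above for a blocked layout with outer strides {512, 256, 64, 16} and a single inner block of 16 along dimension 1, assuming offset0 and the padded offsets are zero:

``` c++
#include <cstdint>
#include <cstdio>

int main() {
    const std::int64_t strides[4] = {512, 256, 64, 16};  // per-block outer strides
    std::int64_t pos[4] = {1, 20, 2, 3};                 // logical position

    // Inner-block contribution: position within the 16-wide block on dim 1
    // (the innermost block stride is 1 when there is a single inner block).
    std::int64_t phys = pos[1] % 16;   // 20 % 16 = 4
    pos[1] /= 16;                      // remaining block index on dim 1

    // Outer contribution: block indices times per-block strides.
    for (int d = 0; d < 4; ++d)
        phys += pos[d] * strides[d];   // 1*512 + 1*256 + 2*64 + 3*16 = 944

    std::printf("%lld\n", (long long)phys);  // prints 948
}
```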
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
deleted file mode 100644
index ec077b308c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
+++ /dev/null
@@ -1,295 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MEMORY_TRACKING_HPP
-#define MEMORY_TRACKING_HPP
-
-#include <assert.h>
-#include <unordered_map>
-
-#include "nstl.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace memory_tracking {
-
-/* Memory tracking capabilities
- *
- * The main purpose of this header file is to provide uniform way to register
- * required memory for a scratchpad at a primitive descriptor creation time
- * and then easily access it having only the base address of the scratchpad.
- *
- * Primitives might contain multiple disjoint parts that require temporary
- * buffers (known as scratchpad) during their execution. A primitive descriptor
- * should summarize all the needs into one single number -- the buffer size
- * that would be requested from a user. At execution time, the corresponding
- * primitive will receive a base pointer to a scratchpad. It then needs to
- * provide each part of the algorithm with the corresponding piece of memory. Three main
- * challenges here are:
- * 1. Track correct offset (from the base scratchpad address) for each piece
- * 2. The algorithm might require different memory pieces to be aligned, so
- * the scratchpad size is no longer just the sum of the sizes of the
- * corresponding subparts.
- * 3. While a primitive is responsible for its scratchpad, the implementation
- * might use some other basic blocks (e.g. cpu_reducer) that also require
- * scratchpad memory. So there should be a simple way of passing the
- * information back and forth between the main algorithm (a primitive) and
- * auxiliary stuff that lives completely separately from it (e.g. reducer).
- *
- * To address these challenges this header file provides 3 structures:
- * 1. registry_t -- the class that stores the information about requested
- * memory. The information includes required size and desired
- * alignment for each piece. This class is also responsible
- * for computing the right offset to a given piece using the
- * base pointer.
- * This class is basically a ledger with all entries.
- * Lives in primitive descriptors.
- *
- * 2. registrar_t -- the interface to a registry_t to book memory. Used at
- * primitive descriptor creation time only. Contains a
- * reference to the corresponding *mutable* registry.
- * Always modifiable.
- * Allows chaining (using prefixes).
- *
- * 3. grantor_t -- the interface to a registry_t to access memory. Used at
- * primitive execution time only. Contains a reference to
- * the corresponding *constant* registry and base pointer.
- * Always constant.
- * Allows chaining (using prefixes).
- *
- * Both registrar_t and grantor_t allow chaining with extra prefix provided.
- * The feature is useful when a primitive offloads a part of its computations
- * to some other primitives which require their own scratchpad space
- * (e.g. reducer). Prefixes are used to avoid key collisions in cases when
- * multiple sub-primitives (e.g. multiple reducers) are used.
- *
- * A short example below demonstrates how to use the aforementioned classes. In it
- * the main primitive is a convolution that uses the scratchpad for keeping a padded
- * bias. It also needs a reducer, which needs its own space as well.
- *
- * ``` c++
- * struct reducer_t {
- * static void init(registrar_t &scratchpad) {
- * // preserve space for the reduction (one page aligned)
- * scratchpad.book(key_reducer_space, sizeof(float) * 980 * 1024, 4096);
- * }
- *
- * void exec(const grantor_t &scratchpad) {
- * // get the pointer to preserved space. scratchpad came from
- * // upper primitive (convolution in this example)
- * auto space = scratchpad.get<float>(key_reducer_space);
- *
- * space[:] += ...;
- * }
- * };
- *
- * struct conv_t {
- * struct pd_t {
- * void init() {
- * registrar_t scratchpad(scratchpad_registry_);
- *
- * // preserve a space for padded bias (using default alignment)
- * scratchpad.book(key_conv_padded_bias, 128);
- *
- * // create a proxy registrar for the reducer. All entries made
- * // by the reducer would live in convolution's registry, but would
- * // have their own `prefix`, so no interference with conv's
- * // buffers.
- * registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
- *
- * reducer_t::init(reducer_scratchpad);
- * }
- *
- * registry_t scratchpad_registry_;
- * }
- *
- * void exec() {
- * // get the base pointer to a scratchpad memory from a user
- * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
- *
- * // create a grantor to the scratchpad (and provide the base
- * // pointer).
- * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
- *
- * // access the padded_bias (need only key name and the grantor)
- * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
- *
- * // to give the `right` grantor to the reducer we need to add the
- * // corresponding prefix, so that the reducer is able to access
- * // its keys. The call is very similar to the one in pd_t::init,
- * // the only difference being the types: grantor_t vs registrar_t.
- * grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
- * reducer->exec(reducer_scratchpad);
- * }
- * };
- * ```
- */
-
-
-/* namespace with common keys and prefixes */
-namespace names {
-enum {
- key_none = 0,
- key_bnorm_tmp_mean,
- key_bnorm_tmp_var,
- key_bnorm_tmp_diff_ss,
- key_bnorm_tmp_stats,
- key_bnorm_reduction,
- key_concat_iptrs,
- key_concat_istrides,
- key_concat_nelems,
- key_concat_optrs,
- key_conv_adjusted_scales,
- key_conv_bia_reduction,
- key_conv_gemm_col,
- key_conv_gemm_imtr,
- key_conv_int_dat_in_acc_dt,
- key_conv_padded_bias,
- key_conv_rtus_space,
- key_conv_tr_diff_dst,
- key_conv_tr_diff_dst_bctx,
- key_conv_tr_src,
- key_conv_tr_src_bctx,
- key_conv_wei_reduction,
- key_conv_wei_bia_reduction,
- key_conv_wei_bia_reduction_bctx,
- key_iprod_int_dat_in_acc_dt,
- key_reducer_space,
- key_reducer_space_bctx,
- key_reorder_wino_plain,
- key_reorder_wino_transform_space,
- key_reorder_rnn_weights_quantization,
- key_reorder_rnn_weights_reduction,
- key_rnn_space,
- key_rnn_ptrs_bia,
- key_rnn_ptrs_wei_layer,
- key_rnn_ptrs_wei_iter,
- key_softmax_reduction,
- key_wino_U,
- key_wino_V,
- key_wino_M,
- key_barrier,
-};
-
-enum {
- prefix_none = 0,
- prefix_reducer_bia,
- prefix_reducer_wei,
-};
-}
-
-// level 0: 00 00 00 xxx
-// level 1: 00 00 aa xxx
-// level 2: 00 aa bb xxx
-// level 3: aa bb cc xxx
-// max # of levels: 3 + 1 (base_level)
-// here:
-// xxx : [1 .. MAX_KEY) : key
-// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
-
-using key_t = uint32_t;
-enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
-
-/// generates global key based on a prefix and a local key
-inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
-
-/// generates global prefix based on the global parent and the local ones
-inline key_t make_prefix(key_t parent_prefix, key_t prefix)
-{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
-
-struct registrar_t;
-struct grantor_t;
-
-struct registry_t {
- void book(const key_t &key, size_t size, size_t alignment) {
- if (size == 0) return;
- assert(offset_map_.count(key) == 0);
-
- size = utils::rnd_up(size, minimal_alignment);
- alignment = nstl::max<size_t>(alignment, minimal_alignment);
- offset_map_[key] = entry_t{size_, size, alignment};
-
- size_ += size + alignment - minimal_alignment;
- }
-
- void *get(const key_t &key, void *base_ptr) const {
- if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
- if (offset_map_.count(key) != 1) return nullptr;
-
- const auto &e = offset_map_.at(key);
- base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
- char *ptr = (char *)base_ptr + e.offset;
- return utils::align_ptr<void>(ptr, e.alignment);
- }
-
- size_t size() const
- { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
-
- registrar_t registrar();
- grantor_t grantor(void *base_ptr) const;
-
-protected:
- enum { minimal_alignment = 64 };
- struct entry_t { size_t offset, size, alignment; };
-
- std::unordered_map<key_t, entry_t> offset_map_;
- size_t size_ = 0;
-};
-
-struct registrar_t {
- enum { default_alignment = 64 };
-
- registrar_t(registry_t &registry): registry_(registry), prefix_(0) {}
- registrar_t(registrar_t &parent, const key_t &prefix)
- : registry_(parent.registry_)
- , prefix_(make_prefix(parent.prefix_, prefix)) {}
-
- void book(const key_t &key, size_t size,
- size_t alignment = default_alignment)
- { registry_.book(make_key(prefix_, key), size, alignment); }
-
-protected:
- registry_t &registry_;
- const key_t prefix_;
-};
-
-struct grantor_t {
- grantor_t(const registry_t &registry, void *base_ptr)
- : registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
- grantor_t(const grantor_t &parent, const key_t &prefix)
- : registry_(parent.registry_)
- , prefix_(make_prefix(parent.prefix_, prefix))
- , base_ptr_(parent.base_ptr_) {}
-
- template <typename T = void> T *get(const key_t &key) const
- { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
-
-protected:
- const registry_t &registry_;
- const key_t prefix_;
- void *base_ptr_;
-};
-
-inline registrar_t registry_t::registrar() { return registrar_t(*this); }
-inline grantor_t registry_t::grantor(void *base_ptr) const
-{ return grantor_t(*this, base_ptr); }
-
-}
-}
-}
-
-#endif
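The key/prefix composition above guarantees that entries booked through a chained registrar cannot collide with the parent's own entries, because local keys stay below MAX_KEY while any prefixed key is at least MAX_KEY. A minimal standalone sketch with the same constants (not part of the deleted sources; the local key value 15 is an arbitrary illustration, not one of the enum values above):

``` c++
#include <cstdint>
#include <cstdio>

using key_t = std::uint32_t;
enum { MAX_KEY = 1u << 10, MAX_PREFIX = 1u << 7 };

// Same composition rules as in the header above.
static key_t make_key(key_t prefix, key_t key) { return prefix + key; }
static key_t make_prefix(key_t parent, key_t prefix)
{ return MAX_PREFIX * parent + MAX_KEY * prefix; }

int main() {
    // The parent primitive books a key directly (prefix 0); a sub-primitive
    // books the same local key through a chained registrar with local prefix 1.
    const key_t parent_key   = make_key(0, 15);                        // 15
    const key_t child_prefix = make_prefix(/*parent*/ 0, /*local*/ 1); // 1024
    const key_t child_key    = make_key(child_prefix, 15);             // 1039

    std::printf("%u %u %u\n", (unsigned)parent_key, (unsigned)child_prefix,
            (unsigned)child_key);
}
```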
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
deleted file mode 100644
index 2ef4a8fddc..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
+++ /dev/null
@@ -1,131 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <stdio.h>
-#include <cinttypes>
-
-#include "mkldnn_debug.h"
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#define DPRINT(...) do { \
- int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
- if (l < 0) return l; \
- if ((size_t)l >= str_len) return -1; \
- written_len += l; str_len -= l; \
-} while(0)
-
-int mkldnn_md2fmt_str(char *str, size_t str_len,
- const mkldnn_memory_desc_t *mdesc) {
- using namespace mkldnn::impl;
-
- if (str == nullptr || str_len <= 1u)
- return -1;
-
- int written_len = 0;
-
- if (mdesc == nullptr) {
- DPRINT("%s::%s::",
- mkldnn_dt2str(data_type::undef),
- mkldnn_fmt_kind2str(format_kind::undef));
- return written_len;
- }
-
- memory_desc_wrapper md(mdesc);
-
- DPRINT("%s:", mkldnn_dt2str(md.data_type()));
-
- bool padded_dims = false, padded_offsets = false;
- for (int d = 0; d < md.ndims(); ++d) {
- if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
- if (md.padded_offsets()[d] != 0) padded_offsets = true;
- }
- bool offset0 = md.offset0();
- DPRINT("%s%s%s:",
- padded_dims ? "p" : "",
- padded_offsets ? "o" : "",
- offset0 ? "0" : "");
-
- DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
-
- if (!md.is_blocking_desc()) {
- /* TODO: extend */
- DPRINT("%s:", "");
- } else {
- const auto &blk = md.blocking_desc();
-
- dims_t blocks;
- md.compute_blocks(blocks);
-
- char dim_chars[MKLDNN_MAX_NDIMS + 1];
-
- bool plain = true;
- for (int d = 0; d < md.ndims(); ++d) {
- dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
- if (blocks[d] != 1) plain = false;
- }
-
- dims_t strides;
- utils::array_copy(strides, blk.strides, md.ndims());
- utils::simultaneous_sort(strides, dim_chars, md.ndims(),
- [](dim_t a, dim_t b) { return b - a; });
-
- dim_chars[md.ndims()] = '\0';
- DPRINT("%s", dim_chars);
-
- if (!plain) {
- for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
- DPRINT("%d%c", (int)blk.inner_blks[iblk],
- 'a' + (char)blk.inner_idxs[iblk]);
- }
- }
-
- DPRINT("%s", ":");
- }
-
- DPRINT("f%lx", (long)md.extra().flags);
-
- return written_len;
-}
-
-int mkldnn_md2dim_str(char *str, size_t str_len,
- const mkldnn_memory_desc_t *mdesc) {
- using namespace mkldnn::impl;
-
- if (str == nullptr || str_len <= 1)
- return -1;
-
- int written_len = 0;
-
- if (mdesc == nullptr || mdesc->ndims == 0) {
- DPRINT("%s", "");
- return written_len;
- }
-
- memory_desc_wrapper md(mdesc);
-
- for (int d = 0; d < md.ndims() - 1; ++d)
- DPRINT("%" PRId64 "x", md.dims()[d]);
- DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
-
- return written_len;
-}
-
-#undef DPRINT
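The dimension-letter part of the format string produced by mkldnn_md2fmt_str() above comes from giving each dimension a letter (upper-case when it carries an inner block) and ordering the letters by decreasing stride. A minimal standalone sketch of just that step (not part of the deleted sources; the stride and block values are illustrative):

``` c++
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // Example layout: 4 dims, outer strides {512, 256, 64, 16}, a 16-wide
    // inner block on dim 1.
    const int ndims = 4;
    const long strides[] = {512, 256, 64, 16};
    const long blocks[]  = {1, 16, 1, 1};

    // Letter per dim: lower-case if unblocked, upper-case if blocked.
    char letters[ndims + 1];
    for (int d = 0; d < ndims; ++d)
        letters[d] = (blocks[d] == 1 ? 'a' : 'A') + static_cast<char>(d);
    letters[ndims] = '\0';

    // Order the letters by decreasing stride to recover the layout string.
    std::vector<int> order = {0, 1, 2, 3};
    std::sort(order.begin(), order.end(),
              [&](int a, int b) { return strides[a] > strides[b]; });

    char tag[ndims + 1];
    for (int i = 0; i < ndims; ++i) tag[i] = letters[order[i]];
    tag[ndims] = '\0';

    // The inner-block suffix ("16b" here) is appended separately in the
    // real code; prints "aBcd16b".
    std::printf("%s16b\n", tag);
}
```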
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
deleted file mode 100644
index 16a8f7ea5e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
+++ /dev/null
@@ -1,365 +0,0 @@
-/*******************************************************************************
-* Copyright 2018-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/* DO NOT EDIT, AUTO-GENERATED */
-
-#include <assert.h>
-
-#include "mkldnn_debug.h"
-#include "mkldnn_types.h"
-
-const char *mkldnn_status2str(mkldnn_status_t v) {
- if (v == mkldnn_success) return "success";
- if (v == mkldnn_out_of_memory) return "out_of_memory";
- if (v == mkldnn_try_again) return "try_again";
- if (v == mkldnn_invalid_arguments) return "invalid_arguments";
- if (v == mkldnn_not_ready) return "not_ready";
- if (v == mkldnn_unimplemented) return "unimplemented";
- if (v == mkldnn_iterator_ends) return "iterator_ends";
- if (v == mkldnn_runtime_error) return "runtime_error";
- if (v == mkldnn_not_required) return "not_required";
- assert(!"unknown status");
- return "unknown status";
-}
-
-const char *mkldnn_dt2str(mkldnn_data_type_t v) {
- if (v == mkldnn_data_type_undef) return "undef";
- if (v == mkldnn_f32) return "f32";
- if (v == mkldnn_s32) return "s32";
- if (v == mkldnn_s8) return "s8";
- if (v == mkldnn_u8) return "u8";
- assert(!"unknown dt");
- return "unknown dt";
-}
-
-const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
- if (v == mkldnn_format_kind_undef) return "undef";
- if (v == mkldnn_format_kind_any) return "any";
- if (v == mkldnn_blocked) return "blocked";
- if (v == mkldnn_format_kind_wino) return "wino";
- if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
- assert(!"unknown fmt_kind");
- return "unknown fmt_kind";
-}
-
-const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
- if (v == mkldnn_format_tag_undef) return "undef";
- if (v == mkldnn_format_tag_any) return "format_tag_any";
- if (v == mkldnn_a) return "a";
- if (v == mkldnn_ab) return "ab";
- if (v == mkldnn_abc) return "abc";
- if (v == mkldnn_abcd) return "abcd";
- if (v == mkldnn_abcde) return "abcde";
- if (v == mkldnn_abcdef) return "abcdef";
- if (v == mkldnn_abdec) return "abdec";
- if (v == mkldnn_acb) return "acb";
- if (v == mkldnn_acbde) return "acbde";
- if (v == mkldnn_acdb) return "acdb";
- if (v == mkldnn_acdeb) return "acdeb";
- if (v == mkldnn_ba) return "ba";
- if (v == mkldnn_bac) return "bac";
- if (v == mkldnn_bacd) return "bacd";
- if (v == mkldnn_bcda) return "bcda";
- if (v == mkldnn_cba) return "cba";
- if (v == mkldnn_cdba) return "cdba";
- if (v == mkldnn_cdeba) return "cdeba";
- if (v == mkldnn_decab) return "decab";
- if (v == mkldnn_Abc16a) return "Abc16a";
- if (v == mkldnn_ABc16a16b) return "ABc16a16b";
- if (v == mkldnn_aBc16b) return "aBc16b";
- if (v == mkldnn_ABc16b16a) return "ABc16b16a";
- if (v == mkldnn_Abc4a) return "Abc4a";
- if (v == mkldnn_aBc4b) return "aBc4b";
- if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
- if (v == mkldnn_ABc4b4a) return "ABc4b4a";
- if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
- if (v == mkldnn_ABc8a8b) return "ABc8a8b";
- if (v == mkldnn_aBc8b) return "aBc8b";
- if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
- if (v == mkldnn_ABc8b8a) return "ABc8b8a";
- if (v == mkldnn_Abcd16a) return "Abcd16a";
- if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
- if (v == mkldnn_aBcd16b) return "aBcd16b";
- if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
- if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
- if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
- if (v == mkldnn_Abcd4a) return "Abcd4a";
- if (v == mkldnn_aBcd4b) return "aBcd4b";
- if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
- if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
- if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
- if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
- if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
- if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
- if (v == mkldnn_aBcd8b) return "aBcd8b";
- if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
- if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
- if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
- if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
- if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
- if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
- if (v == mkldnn_Abcde16a) return "Abcde16a";
- if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
- if (v == mkldnn_aBcde16b) return "aBcde16b";
- if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
- if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
- if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
- if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
- if (v == mkldnn_Abcde4a) return "Abcde4a";
- if (v == mkldnn_aBcde4b) return "aBcde4b";
- if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
- if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
- if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
- if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
- if (v == mkldnn_Abcde8a) return "Abcde8a";
- if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
- if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
- if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
- if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
- if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
- if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
- if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
- if (v == mkldnn_aBcdef16b) return "aBcdef16b";
- if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
- if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
- if (v == mkldnn_aBcdef4b) return "aBcdef4b";
- if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
- if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
- if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
- if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
- if (v == mkldnn_aBdc16b) return "aBdc16b";
- if (v == mkldnn_aBdc4b) return "aBdc4b";
- if (v == mkldnn_aBdc8b) return "aBdc8b";
- if (v == mkldnn_aBdec16b) return "aBdec16b";
- if (v == mkldnn_aBdec4b) return "aBdec4b";
- if (v == mkldnn_aBdec8b) return "aBdec8b";
- if (v == mkldnn_aBdefc16b) return "aBdefc16b";
- if (v == mkldnn_aBdefc4b) return "aBdefc4b";
- if (v == mkldnn_aBdefc8b) return "aBdefc8b";
- if (v == mkldnn_Acb16a) return "Acb16a";
- if (v == mkldnn_Acb4a) return "Acb4a";
- if (v == mkldnn_Acb8a) return "Acb8a";
- if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
- if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
- if (v == mkldnn_Acdb16a) return "Acdb16a";
- if (v == mkldnn_Acdb4a) return "Acdb4a";
- if (v == mkldnn_Acdb8a) return "Acdb8a";
- if (v == mkldnn_Acdeb16a) return "Acdeb16a";
- if (v == mkldnn_Acdeb4a) return "Acdeb4a";
- if (v == mkldnn_Acdeb8a) return "Acdeb8a";
- if (v == mkldnn_BAc16a16b) return "BAc16a16b";
- if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
- if (v == mkldnn_format_tag_last) return "format_tag_last";
- if (v == mkldnn_x) return "x";
- if (v == mkldnn_nc) return "nc";
- if (v == mkldnn_cn) return "cn";
- if (v == mkldnn_ncw) return "ncw";
- if (v == mkldnn_nwc) return "nwc";
- if (v == mkldnn_nchw) return "nchw";
- if (v == mkldnn_nhwc) return "nhwc";
- if (v == mkldnn_chwn) return "chwn";
- if (v == mkldnn_ncdhw) return "ncdhw";
- if (v == mkldnn_ndhwc) return "ndhwc";
- if (v == mkldnn_oi) return "oi";
- if (v == mkldnn_io) return "io";
- if (v == mkldnn_oiw) return "oiw";
- if (v == mkldnn_wio) return "wio";
- if (v == mkldnn_oihw) return "oihw";
- if (v == mkldnn_hwio) return "hwio";
- if (v == mkldnn_ihwo) return "ihwo";
- if (v == mkldnn_iohw) return "iohw";
- if (v == mkldnn_oidhw) return "oidhw";
- if (v == mkldnn_dhwio) return "dhwio";
- if (v == mkldnn_goiw) return "goiw";
- if (v == mkldnn_goihw) return "goihw";
- if (v == mkldnn_hwigo) return "hwigo";
- if (v == mkldnn_giohw) return "giohw";
- if (v == mkldnn_goidhw) return "goidhw";
- if (v == mkldnn_tnc) return "tnc";
- if (v == mkldnn_ntc) return "ntc";
- if (v == mkldnn_ldsnc) return "ldsnc";
- if (v == mkldnn_ldigo) return "ldigo";
- if (v == mkldnn_ldgoi) return "ldgoi";
- if (v == mkldnn_ldgo) return "ldgo";
- if (v == mkldnn_nCdhw16c) return "nCdhw16c";
- if (v == mkldnn_nCdhw4c) return "nCdhw4c";
- if (v == mkldnn_nCdhw8c) return "nCdhw8c";
- if (v == mkldnn_nChw16c) return "nChw16c";
- if (v == mkldnn_nChw4c) return "nChw4c";
- if (v == mkldnn_nChw8c) return "nChw8c";
- if (v == mkldnn_nCw16c) return "nCw16c";
- if (v == mkldnn_nCw4c) return "nCw4c";
- if (v == mkldnn_nCw8c) return "nCw8c";
- if (v == mkldnn_IOw16o16i) return "IOw16o16i";
- if (v == mkldnn_OIw16i16o) return "OIw16i16o";
- if (v == mkldnn_OIw16o16i) return "OIw16o16i";
- if (v == mkldnn_Oiw16o) return "Oiw16o";
- if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
- if (v == mkldnn_OIw4i4o) return "OIw4i4o";
- if (v == mkldnn_Oiw4o) return "Oiw4o";
- if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
- if (v == mkldnn_OIw8i8o) return "OIw8i8o";
- if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
- if (v == mkldnn_OIw8o8i) return "OIw8o8i";
- if (v == mkldnn_Owi16o) return "Owi16o";
- if (v == mkldnn_Owi4o) return "Owi4o";
- if (v == mkldnn_Owi8o) return "Owi8o";
- if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
- if (v == mkldnn_Ohwi16o) return "Ohwi16o";
- if (v == mkldnn_Ohwi4o) return "Ohwi4o";
- if (v == mkldnn_Ohwi8o) return "Ohwi8o";
- if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
- if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
- if (v == mkldnn_Oihw16o) return "Oihw16o";
- if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
- if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
- if (v == mkldnn_Oihw4o) return "Oihw4o";
- if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
- if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
- if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
- if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
- if (v == mkldnn_Odhwi16o) return "Odhwi16o";
- if (v == mkldnn_Odhwi4o) return "Odhwi4o";
- if (v == mkldnn_Odhwi8o) return "Odhwi8o";
- if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
- if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
- if (v == mkldnn_Oidhw16o) return "Oidhw16o";
- if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
- if (v == mkldnn_Oidhw4o) return "Oidhw4o";
- if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
- if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
- if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
- if (v == mkldnn_Goiw16g) return "Goiw16g";
- if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
- if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
- if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
- if (v == mkldnn_gOiw16o) return "gOiw16o";
- if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
- if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
- if (v == mkldnn_gOiw4o) return "gOiw4o";
- if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
- if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
- if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
- if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
- if (v == mkldnn_gOwi16o) return "gOwi16o";
- if (v == mkldnn_gOwi4o) return "gOwi4o";
- if (v == mkldnn_gOwi8o) return "gOwi8o";
- if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
- if (v == mkldnn_gOhwi16o) return "gOhwi16o";
- if (v == mkldnn_gOhwi4o) return "gOhwi4o";
- if (v == mkldnn_gOhwi8o) return "gOhwi8o";
- if (v == mkldnn_Goihw16g) return "Goihw16g";
- if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
- if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
- if (v == mkldnn_gOihw16o) return "gOihw16o";
- if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
- if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
- if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
- if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
- if (v == mkldnn_gOihw4o) return "gOihw4o";
- if (v == mkldnn_Goihw8g) return "Goihw8g";
- if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
- if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
- if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
- if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
- if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
- if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
- if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
- if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
- if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
- if (v == mkldnn_gOidhw16o) return "gOidhw16o";
- if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
- if (v == mkldnn_gOidhw4o) return "gOidhw4o";
- if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
- if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
- if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
- assert(!"unknown fmt_tag");
- return "unknown fmt_tag";
-}
-
-const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
- if (v == mkldnn_prop_kind_undef) return "undef";
- if (v == mkldnn_forward_training) return "forward_training";
- if (v == mkldnn_forward_inference) return "forward_inference";
- if (v == mkldnn_forward_scoring) return "forward_scoring";
- if (v == mkldnn_forward) return "forward";
- if (v == mkldnn_backward) return "backward";
- if (v == mkldnn_backward_data) return "backward_data";
- if (v == mkldnn_backward_weights) return "backward_weights";
- if (v == mkldnn_backward_bias) return "backward_bias";
- assert(!"unknown prop_kind");
- return "unknown prop_kind";
-}
-
-const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
- if (v == mkldnn_undefined_primitive) return "undef";
- if (v == mkldnn_reorder) return "reorder";
- if (v == mkldnn_shuffle) return "shuffle";
- if (v == mkldnn_concat) return "concat";
- if (v == mkldnn_sum) return "sum";
- if (v == mkldnn_convolution) return "convolution";
- if (v == mkldnn_deconvolution) return "deconvolution";
- if (v == mkldnn_eltwise) return "eltwise";
- if (v == mkldnn_softmax) return "softmax";
- if (v == mkldnn_pooling) return "pooling";
- if (v == mkldnn_lrn) return "lrn";
- if (v == mkldnn_batch_normalization) return "batch_normalization";
- if (v == mkldnn_inner_product) return "inner_product";
- if (v == mkldnn_rnn) return "rnn";
- assert(!"unknown prim_kind");
- return "unknown prim_kind";
-}
-
-const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
- if (v == mkldnn_alg_kind_undef) return "undef";
- if (v == mkldnn_convolution_direct) return "convolution_direct";
- if (v == mkldnn_convolution_winograd) return "convolution_winograd";
- if (v == mkldnn_convolution_auto) return "convolution_auto";
- if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
- if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
- if (v == mkldnn_eltwise_relu) return "eltwise_relu";
- if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
- if (v == mkldnn_eltwise_elu) return "eltwise_elu";
- if (v == mkldnn_eltwise_square) return "eltwise_square";
- if (v == mkldnn_eltwise_abs) return "eltwise_abs";
- if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
- if (v == mkldnn_eltwise_linear) return "eltwise_linear";
- if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
- if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
- if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
- if (v == mkldnn_pooling_max) return "pooling_max";
- if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
- if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
- if (v == mkldnn_pooling_avg) return "pooling_avg";
- if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
- if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
- if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
- if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
- if (v == mkldnn_vanilla_gru) return "vanilla_gru";
- if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
- assert(!"unknown alg_kind");
- return "unknown alg_kind";
-}
-
-const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
- if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
- if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
- if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
- if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
- if (v == mkldnn_unidirectional) return "unidirectional";
- assert(!"unknown rnn_direction");
- return "unknown rnn_direction";
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
deleted file mode 100644
index 7e5789e2c3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
+++ /dev/null
@@ -1,115 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_THREAD_HPP
-#define MKLDNN_THREAD_HPP
-
-#include "utils.hpp"
-#include "z_magic.hpp"
-
-#define MKLDNN_THR_SEQ 0
-#define MKLDNN_THR_OMP 1
-#define MKLDNN_THR_TBB 2
-
-/* Ideally the condition below should never trigger (when the library is built
- * with the regular cmake setup). For third-party projects that build the
- * library from sources on their own, try to guess the right threading... */
-#if !defined(MKLDNN_THR)
-# define MKLDNN_THR MKLDNN_THR_TBB
-#endif
-
-#if MKLDNN_THR == MKLDNN_THR_SEQ
-#define MKLDNN_THR_SYNC 1
-inline int mkldnn_get_max_threads() { return 1; }
-inline int mkldnn_get_num_threads() { return 1; }
-inline int mkldnn_get_thread_num() { return 0; }
-inline int mkldnn_in_parallel() { return 0; }
-inline void mkldnn_thr_barrier() {}
-
-#define PRAGMA_OMP(...)
-
-#elif MKLDNN_THR == MKLDNN_THR_OMP
-#include <omp.h>
-#define MKLDNN_THR_SYNC 1
-
-inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
-inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
-inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
-inline int mkldnn_in_parallel() { return omp_in_parallel(); }
-inline void mkldnn_thr_barrier() {
-# pragma omp barrier
-}
-
-#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
-
-#elif MKLDNN_THR == MKLDNN_THR_TBB
-#include "tbb/task_arena.h"
-#include "tbb/parallel_for.h"
-#define MKLDNN_THR_SYNC 0
-
-inline int mkldnn_get_max_threads()
-{ return tbb::this_task_arena::max_concurrency(); }
-inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
-inline int mkldnn_get_thread_num()
-{ return tbb::this_task_arena::current_thread_index(); }
-inline int mkldnn_in_parallel() { return 0; }
-inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
-
-#define PRAGMA_OMP(...)
-
-#endif
-
-/* MSVC still supports only OpenMP 2.0 */
-#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
-# define collapse(x)
-# define PRAGMA_OMP_SIMD(...)
-#else
-# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
-#endif // defined(_MSC_VER) && !defined(__INTEL_COMPILER)
-
-namespace mkldnn {
-namespace impl {
-
-inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
-
-template <typename T, typename U>
-inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
- T n_min = 1;
- T &n_my = n_end;
- if (team <= 1 || n == 0) {
- n_start = 0;
- n_my = n;
- } else if (n_min == 1) {
- // team = T1 + T2
- // n = T1*n1 + T2*n2 (n1 - n2 = 1)
- T n1 = utils::div_up(n, (T)team);
- T n2 = n1 - 1;
- T T1 = n - n2 * (T)team;
- n_my = (T)tid < T1 ? n1 : n2;
- n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
- }
-
- n_end += n_start;
-}
-
-} // namespace impl
-} // namespace mkldnn
-
-#include "mkldnn_thread_parallel_nd.hpp"
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
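As a quick illustration of the splitting scheme described in the balance211 comment above (a sketch, not part of the removed header):

    #include "mkldnn_thread.hpp" // within the library source tree

    // balance211 splits n items over a team so chunk sizes differ by at most 1.
    // Example: n = 10 items over a team of 4 threads.
    size_t start = 0, end = 0;
    mkldnn::impl::balance211((size_t)10, 4, 0, start, end); // thread 0 -> [0, 3)
    mkldnn::impl::balance211((size_t)10, 4, 1, start, end); // thread 1 -> [3, 6)
    mkldnn::impl::balance211((size_t)10, 4, 2, start, end); // thread 2 -> [6, 8)
    mkldnn::impl::balance211((size_t)10, 4, 3, start, end); // thread 3 -> [8, 10)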
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
deleted file mode 100644
index 50f9b29622..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
+++ /dev/null
@@ -1,277 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
-#define MKLDNN_THREAD_PARALLEL_ND_HPP
-
-/* This header must be included by mkldnn_thread.hpp only */
-
-/* Functions:
- * - parallel(nthr, f) - executes f in parallel using at most
- * nthr threads. If nthr equals 0,
- * mkldnn_get_max_threads() threads
- * are used
- * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
- * created threads
- * - parallel_nd(dims..., f) - creates a parallel section and then
- * calls for_nd
- * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
- * calls for_nd (mostly for convenience)
- */
-
-namespace mkldnn {
-namespace impl {
-
-/* general parallelization */
-template <typename F>
-void parallel(int nthr, F f) {
- if (nthr == 0) nthr = mkldnn_get_max_threads();
-#if MKLDNN_THR == MKLDNN_THR_SEQ
- assert(nthr == 1);
- f(0, 1);
-#elif MKLDNN_THR == MKLDNN_THR_OMP
- if (nthr == 1) { f(0, 1); return; }
-# pragma omp parallel num_threads(nthr)
- f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
-#elif MKLDNN_THR == MKLDNN_THR_TBB
- if (nthr == 1) { f(0, 1); return; }
- tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
-#endif
-}
-
-/* for_nd section */
-
-template <typename T0, typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
- T0 start{0}, end{0};
- balance211(D0, nthr, ithr, start, end);
- for (T0 d0 = start; d0 < end; ++d0) f(d0);
-}
-
-template <typename T0, typename T1, typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
- const size_t work_amount = (size_t)D0 * D1;
- if (work_amount == 0) return;
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- T0 d0{0}; T1 d1{0};
- utils::nd_iterator_init(start, d0, D0, d1, D1);
- for (size_t iwork = start; iwork < end; ++iwork) {
- f(d0, d1);
- utils::nd_iterator_step(d0, D0, d1, D1);
- }
-}
-
-template <typename T0, typename T1, typename T2, typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
- const T2 &D2, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2;
- if (work_amount == 0) return;
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- T0 d0{0}; T1 d1{0}; T2 d2{0};
- utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
- for (size_t iwork = start; iwork < end; ++iwork) {
- f(d0, d1, d2);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
- }
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
- const T2 &D2, const T3 &D3, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
- if (work_amount == 0) return;
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
- utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
- for (size_t iwork = start; iwork < end; ++iwork) {
- f(d0, d1, d2, d3);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
- }
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename T4,
- typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
- const T2 &D2, const T3 &D3, const T4 &D4, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
- if (work_amount == 0) return;
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
- utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
- for (size_t iwork = start; iwork < end; ++iwork) {
- f(d0, d1, d2, d3, d4);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
- }
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename T4,
- typename T5, typename F>
-void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
- const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
- if (work_amount == 0) return;
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
- utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
- d5, D5);
- for (size_t iwork = start; iwork < end; ++iwork) {
- f(d0, d1, d2, d3, d4, d5);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
- }
-}
-
-// Compute the total work amount as the product of the dims, skipping the
-// trailing lambda function in the parameter pack.
-template <typename T>
-constexpr size_t get_work_amount(const T &v) { return 1; }
-template <typename T, typename ...Args>
-constexpr size_t get_work_amount(const T &v, Args &&...args)
-{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
-
-/* parallel_nd and parallel_nd_in_omp section */
-
-#if MKLDNN_THR != MKLDNN_THR_TBB
-template <typename ...Args>
-void parallel_nd(Args &&...args) {
-#if MKLDNN_THR == MKLDNN_THR_SEQ
- for_nd(0, 1, utils::forward<Args>(args)...);
-#elif MKLDNN_THR == MKLDNN_THR_OMP
- const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
-# pragma omp parallel if (do_parallel)
- {
- const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
- const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
- for_nd(ithr, nthr, utils::forward<Args>(args)...);
- }
-#endif
-}
-#else // MKLDNN_THR != MKLDNN_THR_TBB
-
-// gcc 4.8 has a bug with passing a parameter pack to lambdas,
-// so all the cases have to be instantiated explicitly.
-
-template <typename T0, typename F>
-void parallel_nd(const T0 &D0, F f) {
- const size_t work_amount = (size_t)D0;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(T0(iwork));
- }
- }, tbb::static_partitioner());
-}
-
-template <typename T0, typename T1, typename F>
-void parallel_nd(const T0 &D0, const T1 &D1, F f) {
- const size_t work_amount = (size_t)D0 * D1;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- T0 d0{0}; T1 d1{0};
- utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(d0, d1);
- utils::nd_iterator_step(d0, D0, d1, D1);
- }
- }, tbb::static_partitioner());
-}
-
-template <typename T0, typename T1, typename T2, typename F>
-void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- T0 d0{0}; T1 d1{0}; T2 d2{0};
- utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(d0, d1, d2);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
- }
- }, tbb::static_partitioner());
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename F>
-void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
- utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(d0, d1, d2, d3);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
- }
- }, tbb::static_partitioner());
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename T4,
- typename F>
-void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
- const T4 &D4, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
- utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(d0, d1, d2, d3, d4);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
- }
- }, tbb::static_partitioner());
-}
-
-template <typename T0, typename T1, typename T2, typename T3, typename T4,
- typename T5, typename F>
-void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
- const T4 &D4, const T5 &D5, F f) {
- const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
- if (work_amount == 0) return;
- tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
- T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
- utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
- d5, D5);
- for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
- f(d0, d1, d2, d3, d4, d5);
- utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
- }
- }, tbb::static_partitioner());
-}
-#endif
-
-template <typename ...Args>
-void parallel_nd_in_omp(Args &&...args) {
-#if MKLDNN_THR == MKLDNN_THR_SEQ
- for_nd(0, 1, utils::forward<Args>(args)...);
-#elif MKLDNN_THR == MKLDNN_THR_OMP
- for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
- utils::forward<Args>(args)...);
-#elif MKLDNN_THR == MKLDNN_THR_TBB
- assert(!"unsupported parallel_nd_in_omp()");
-#endif
-}
-
-} // namespace impl
-} // namespace mkldnn
-
-#endif
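A short usage sketch of the helpers declared above (illustrative; process() is a hypothetical per-element function, not part of the removed header):

    #include "mkldnn_thread.hpp" // parallel_nd is exposed through this header

    // Iterate an N x C grid; the body runs once per (n, c) pair and the work
    // is split across the active threading backend (SEQ, OMP, or TBB).
    const int N = 8, C = 64;
    mkldnn::impl::parallel_nd(N, C, [&](int n, int c) {
        process(n, c); // hypothetical per-element work
    });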
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
deleted file mode 100644
index aa671a0b6e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
+++ /dev/null
@@ -1,77 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef MKLDNN_TRAITS_HPP
-#define MKLDNN_TRAITS_HPP
-
-#include <assert.h>
-#include <stdint.h>
-
-#include "mkldnn.h"
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-#include "z_magic.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-template <data_type_t> struct prec_traits {}; /* ::type -> float */
-template <typename> struct data_traits {}; /* ::data_type -> f32 */
-template <int> struct typesize_traits {}; /* byte size -> ::type */
-template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
-
-template <> struct prec_traits<data_type::f32> { typedef float type; };
-template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
-template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
-template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
-
-template <> struct data_traits<float>
-{ static constexpr data_type_t data_type = data_type::f32; };
-template <> struct data_traits<int32_t>
-{ static constexpr data_type_t data_type = data_type::s32; };
-template <> struct data_traits<int8_t>
-{ static constexpr data_type_t data_type = data_type::s8; };
-template <> struct data_traits<uint8_t>
-{ static constexpr data_type_t data_type = data_type::u8; };
-
-template <> struct typesize_traits<4> { typedef float type; };
-template <> struct typesize_traits<2> { typedef int16_t type; };
-template <> struct typesize_traits<1> { typedef uint8_t type; };
-
-#define PKIND_TRAITS_INST(op) \
-template <> struct pkind_traits<primitive_kind::op> { \
- typedef CONCAT2(op, _desc_t) desc_type; \
- static constexpr query_t query_d = query::CONCAT2(op, _d); \
-}
-PKIND_TRAITS_INST(convolution);
-PKIND_TRAITS_INST(deconvolution);
-PKIND_TRAITS_INST(shuffle);
-PKIND_TRAITS_INST(eltwise);
-PKIND_TRAITS_INST(softmax);
-PKIND_TRAITS_INST(pooling);
-PKIND_TRAITS_INST(lrn);
-PKIND_TRAITS_INST(batch_normalization);
-PKIND_TRAITS_INST(inner_product);
-PKIND_TRAITS_INST(rnn);
-#undef PKIND_TRAITS_INST
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
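A small sketch of how these trait maps were typically consumed (illustrative only, not taken from the removed header):

    #include "mkldnn_traits.hpp" // within the library source tree

    // prec_traits maps a runtime data_type_t value to a concrete C++ type;
    // data_traits goes the other way. Both are compile-time lookups.
    using f32_t = mkldnn::impl::prec_traits<mkldnn::impl::data_type::f32>::type; // float
    static_assert(mkldnn::impl::data_traits<float>::data_type
            == mkldnn::impl::data_type::f32, "f32 round-trips through the traits");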
diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
deleted file mode 100644
index f89ea999e2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
+++ /dev/null
@@ -1,193 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef NSTL_HPP
-#define NSTL_HPP
-
-#include <stdint.h>
-#include <limits.h>
-#include <float.h>
-
-#include <vector>
-#include <map>
-
-#include "z_magic.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-void *malloc(size_t size, int alignment);
-void free(void *p);
-
-struct c_compatible {
- enum { default_alignment = 64 };
- static void *operator new(size_t sz) {
- return malloc(sz, default_alignment);
- }
- static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
- static void *operator new[](size_t sz) {
- return malloc(sz, default_alignment);
- }
- static void operator delete(void *p) { free(p); }
- static void operator delete[](void *p) { free(p); }
-};
-
-namespace nstl {
-
-template<typename T>
-inline const T abs(const T& a) {
- return a >= 0 ? a : -a;
-}
-
-template<typename T>
-inline const T& max(const T& a, const T& b) {
- return a > b ? a : b;
-}
-
-template<typename T>
-inline const T& min(const T& a, const T& b) {
- return a < b ? a : b;
-}
-
-template<typename T> void swap(T& t1, T& t2) {
- T tmp(t1);
- t1 = t2;
- t2 = tmp;
-}
-
-// Rationale: MKL-DNN needs a numeric limits implementation that does not
-// generate dependencies on C++ run-time libraries.
-
-template<typename T> struct numeric_limits;
-
-template<> struct numeric_limits<float> {
- static constexpr float lowest() { return -FLT_MAX; }
- static constexpr float max() { return FLT_MAX; }
-};
-
-template<> struct numeric_limits<int32_t> {
- static constexpr int lowest() { return INT32_MIN; }
- static constexpr int max() { return INT32_MAX; }
-};
-
-template<> struct numeric_limits<int16_t> {
- static constexpr int16_t lowest() { return INT16_MIN; }
- static constexpr int16_t max() { return INT16_MAX; }
-};
-
-template<> struct numeric_limits<int8_t> {
- static constexpr int8_t lowest() { return INT8_MIN; }
- static constexpr int8_t max() { return INT8_MAX; }
-};
-
-template<> struct numeric_limits<uint8_t> {
- static constexpr uint8_t lowest() { return 0; }
- static constexpr uint8_t max() { return UINT8_MAX; }
-};
-
-template<typename T> struct is_integral
-{ static constexpr bool value = false; };
-template<> struct is_integral<int32_t> { static constexpr bool value = true; };
-template<> struct is_integral<int16_t> { static constexpr bool value = true; };
-template<> struct is_integral<int8_t> { static constexpr bool value = true; };
-template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
-
-template <typename T, typename U> struct is_same
-{ static constexpr bool value = false; };
-template <typename T> struct is_same<T, T>
-{ static constexpr bool value = true; };
-
-// Rationale: MKL-DNN needs container implementations that do not generate
-// dependencies on C++ run-time libraries.
-//
-// Implementation philosophy: caller is responsible to check if the operation
-// is valid. The only functions that have to return status are those that
-// depend on memory allocation or similar operations.
-//
-// This means that e.g. an operator [] does not have to check for boundaries.
-// The caller should have checked the boundaries. If it did not we crash and
-// burn: this is a bug in MKL-DNN and throwing an exception would not have been
-// recoverable.
-//
-// On the other hand, insert() or resize() or a similar operation needs to
-// return a status because the outcome depends on factors external to the
-// caller. The situation is probably not recoverable either, but MKL-DNN
-// needs to be nice and report "out of memory" to the users.
-
-enum nstl_status_t {
- success = 0,
- out_of_memory
-};
-
-template <typename T> class vector: public c_compatible {
-private:
- std::vector<T> _impl;
-public:
- typedef typename std::vector<T>::iterator iterator;
- typedef typename std::vector<T>::const_iterator const_iterator;
- typedef typename std::vector<T>::size_type size_type;
- vector() {}
- vector(size_type n): _impl(n) {}
- vector(size_type n, const T &value): _impl(n, value) {}
- template <typename input_iterator>
- vector(input_iterator first, input_iterator last): _impl(first, last) {}
- ~vector() {}
- size_type size() const { return _impl.size(); }
- T& operator[] (size_type i) { return _impl[i]; }
- const T& operator[] (size_type i) const { return _impl[i]; }
- iterator begin() { return _impl.begin(); }
- const_iterator begin() const { return _impl.begin(); }
- iterator end() { return _impl.end(); }
- const_iterator end() const { return _impl.end(); }
- template <typename input_iterator>
- nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
- {
- _impl.insert(pos, begin, end);
- return success;
- }
- void clear() { _impl.clear(); }
- void push_back(const T& t) { _impl.push_back(t); }
- void resize(size_type count) { _impl.resize(count); }
- void reserve(size_type count) { _impl.reserve(count); }
-};
-
-template <typename Key, typename T> class map: public c_compatible {
-private:
- std::map<Key, T> _impl;
-public:
- typedef typename std::map<Key, T>::iterator iterator;
- typedef typename std::map<Key, T>::const_iterator const_iterator;
- typedef typename std::map<Key, T>::size_type size_type;
- map() {}
- ~map() {}
- size_type size() const { return _impl.size(); }
- T& operator[](const Key &k) { return _impl[k]; }
- const T& operator[](const Key &k) const { return _impl[k]; }
- iterator begin() { return _impl.begin(); }
- const_iterator begin() const { return _impl.begin(); }
- iterator end() { return _impl.end(); }
- const_iterator end() const { return _impl.end(); }
- template <typename input_iterator>
- void clear() { _impl.clear(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
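A brief sketch of the intended usage pattern (illustrative; error handling is reduced to a comment):

    #include "nstl.hpp" // within the library source tree

    // nstl::vector mirrors std::vector, but insert() returns an nstl_status_t
    // so callers can report out-of-memory without relying on exceptions.
    mkldnn::impl::nstl::vector<int> v;
    v.push_back(1);
    int extra[] = {2, 3};
    if (v.insert(v.end(), extra, extra + 2) != mkldnn::impl::nstl::success) {
        // would report out_of_memory to the caller
    }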
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
deleted file mode 100644
index be96e654ff..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
+++ /dev/null
@@ -1,114 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t pooling_desc_init(pooling_desc_t *pool_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t kernel, const dims_t padding_l,
- const dims_t padding_r, padding_kind_t padding_kind) {
- bool args_ok = true
- && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
- && one_of(alg_kind, pooling_max,
- pooling_avg_include_padding,
- pooling_avg_exclude_padding)
- && one_of(padding_kind, padding_kind::padding_zero);
- if (!args_ok) return invalid_arguments;
-
- if (padding_r == nullptr) padding_r = padding_l;
-
- auto pd = pooling_desc_t();
- pd.primitive_kind = primitive_kind::pooling;
- pd.prop_kind = prop_kind;
- pd.alg_kind = alg_kind;
- pd.src_desc.ndims = src_desc->ndims;
-
- const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
-
- pd.diff_src_desc = pd.src_desc = zero_md();
- pd.diff_dst_desc = pd.dst_desc = zero_md();
-
- (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
- (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
-
- int sp_dims = src_desc->ndims - 2;
- utils::array_copy(pd.strides, strides, sp_dims);
- utils::array_copy(pd.kernel, kernel, sp_dims);
- utils::array_copy(pd.padding[0], padding_l, sp_dims);
- utils::array_copy(pd.padding[1], padding_r, sp_dims);
-
- pd.padding_kind = padding_kind;
- if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
- pooling_avg_exclude_padding)) {
- pd.accum_data_type = types::default_accum_data_type(
- src_desc->data_type, dst_desc->data_type);
- } else {
- pd.accum_data_type = dst_desc->data_type;
- }
-
- bool consistency = true
- && utils::one_of(src_desc->ndims, 4, 5)
- && utils::one_of(dst_desc->ndims, 4, 5)
- && src_desc->dims[0] == dst_desc->dims[0]
- && src_desc->dims[1] == dst_desc->dims[1];
- for (int i = 2; i < src_desc->ndims; ++i)
- consistency = consistency && (
- (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
- + padding_r[i - 2]) / strides[i - 2] + 1
- == dst_desc->dims[i]);
- if (!consistency) return invalid_arguments;
-
- *pool_desc = pd;
- return success;
-}
-}
-
-status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
- prop_kind_t prop_kind, alg_kind_t alg_kind,
- const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
- const dims_t strides, const dims_t kernel, const dims_t padding_l,
- const dims_t padding_r, padding_kind_t padding_kind) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
- dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
-}
-
-status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
- alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
- const memory_desc_t *diff_dst_desc, const dims_t strides,
- const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
- padding_kind_t padding_kind) {
- return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
- diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
- padding_r, padding_kind);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
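For reference, the consistency loop in pooling_desc_init above encodes the standard pooling output-size relation; a worked instance with assumed numbers:

    // dst_dim = (src_dim - kernel + pad_l + pad_r) / stride + 1
    // e.g. src_dim = 112, kernel = 3, pad_l = pad_r = 1, stride = 2:
    //      (112 - 3 + 1 + 1) / 2 + 1 = 55 + 1 = 56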
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
deleted file mode 100644
index 4c9c009412..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
+++ /dev/null
@@ -1,238 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef POOLING_PD_HPP
-#define POOLING_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct pooling_fwd_pd_t;
-
-struct pooling_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::pooling;
-
- pooling_pd_t(engine_t *engine,
- const pooling_desc_t *adesc,
- const primitive_attr_t *attr,
- const pooling_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , ws_md_()
- {}
-
- const pooling_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::pooling_d:
- *(const pooling_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common pooling aux functions */
-
- dim_t MB() const { return src_desc().dims[0]; }
- dim_t C() const { return src_desc().dims[1]; }
-
- dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
- dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
- dim_t IW() const { return src_desc().dims[ndims() - 1]; }
-
- dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
- dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
- dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
-
- dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
- dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
- dim_t KW() const { return desc_.kernel[ndims() - 3]; }
-
- dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
- dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
- dim_t KSW() const { return desc_.strides[ndims() - 3]; }
-
- dim_t padFront() const
- { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
- dim_t padBack() const
- { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
- dim_t padT() const
- { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
- dim_t padB() const
- { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
- dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
- dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
-
- int ndims() const { return src_desc().ndims; }
- bool is_3d() const { return ndims() == 5; }
-
- bool has_zero_dim_memory() const
- { return memory_desc_wrapper(src_desc()).has_zero_dim(); }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
-protected:
- pooling_desc_t desc_;
- const pooling_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t ws_md_;
-
- void init_default_ws() {
- ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
- ws_md_.data_type = indices_data_type();
- }
-
- data_type_t indices_data_type() const {
- /* the simplest way to express 256... */
- const int u8_max = nstl::numeric_limits<
- typename prec_traits<data_type::u8>::type>::max();
- return utils::array_product(desc()->kernel, ndims()) <= u8_max
- ? data_type::u8 : data_type::s32;
- }
-
-private:
- const memory_desc_t &src_desc() const
- { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
- const memory_desc_t &dst_desc() const
- { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
-};
-
-struct pooling_fwd_pd_t: public pooling_pd_t {
- typedef pooling_fwd_pd_t base_class;
- typedef pooling_fwd_pd_t hint_class;
-
- pooling_fwd_pd_t(engine_t *engine,
- const pooling_desc_t *adesc,
- const primitive_attr_t *attr,
- const pooling_fwd_pd_t *hint_fwd_pd)
- : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
- , src_md_(desc_.src_desc)
- , dst_md_(desc_.dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override
- { return 1 + (workspace_md() != nullptr); }
-
-protected:
- memory_desc_t src_md_;
- memory_desc_t dst_md_;
-
- virtual status_t set_default_params() {
- if (dst_md()->format_kind != format_kind::any)
- return status::success;
-
- if (src_md()->format_kind != format_kind::blocked)
- return status::unimplemented;
-
- return memory_desc_init_by_blocking_desc(dst_md_,
- src_md_.format_desc.blocking);
- }
-};
-
-struct pooling_bwd_pd_t: public pooling_pd_t {
- typedef pooling_bwd_pd_t base_class;
- typedef pooling_fwd_pd_t hint_class;
-
- pooling_bwd_pd_t(engine_t *engine,
- const pooling_desc_t *adesc,
- const primitive_attr_t *attr,
- const pooling_fwd_pd_t *hint_fwd_pd)
- : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_src_md_(desc_.diff_src_desc)
- , diff_dst_md_(desc_.diff_dst_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_DIFF_DST)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::input;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_src_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_dst_md_ : nullptr; }
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
-
- virtual int n_inputs() const override
- { return 1 + (workspace_md() != nullptr); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t diff_src_md_;
- memory_desc_t diff_dst_md_;
-
- virtual status_t set_default_params() {
- if (diff_src_md()->format_kind != format_kind::any)
- return status::success;
-
- if (diff_dst_md()->format_kind != format_kind::blocked)
- return status::unimplemented;
-
- return memory_desc_init_by_blocking_desc(diff_src_md_,
- diff_dst_md_.format_desc.blocking);
- }
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
deleted file mode 100644
index fdf6522f62..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "primitive_desc.hpp"
-#include "primitive.hpp"
-#include "type_helpers.hpp"
-#include "stream.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::primitive_kind;
-
-namespace {
-// XXX: this is a huge hammer. This disables any and all msan checks on
-// primitives outputs.
-//
-// A proper approach would be an implementation-specific unpoisoning.
-void unpoison_outputs(const exec_args_t &args) {
- for(const auto &arg: args) {
- if (arg.second.is_const) continue;
- auto *mem = arg.second.mem;
- void *p;
- mem->get_data_handle(&p);
- size_t s = memory_desc_wrapper(*mem->md()).size();
- msan_unpoison(p, s);
- }
-}
-}
-
-status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
- if (primitive_desc) delete primitive_desc;
- return success;
-}
-
-status_t mkldnn_primitive_create(primitive_t **primitive,
- const primitive_desc_t *primitive_desc) {
- if (utils::any_null(primitive, primitive_desc))
- return invalid_arguments;
- return primitive_desc->create_primitive(primitive);
-}
-
-status_t mkldnn_primitive_execute(const primitive_t *primitive,
- stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
- bool ok = true
- && !utils::any_null(primitive, stream)
- && primitive->engine() == stream->engine()
- && IMPLICATION(nargs > 0, c_args != nullptr);
- if (!ok) return invalid_arguments;
-
- exec_args_t args;
- status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
- if (status != status::success) return status;
-
- exec_ctx_t ctx(stream, std::move(args));
-
- if (mkldnn_verbose()->level) {
- double ms = get_msec();
- status = primitive->execute(ctx);
- ms = get_msec() - ms;
- printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
- fflush(0);
- } else {
- status = primitive->execute(ctx);
- }
-
- if (msan_enabled) unpoison_outputs(ctx.args());
-
- return status;
-}
-
-status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
- const primitive_desc_t **primitive_desc) {
- if (utils::any_null(primitive, primitive_desc))
- return invalid_arguments;
- return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
- primitive->pd());
-}
-
-status_t mkldnn_primitive_destroy(primitive_t *primitive) {
- if (primitive != nullptr)
- delete primitive;
- return success;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
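A caller-side sketch of the sequence these entry points implement (the pd, stream, src_mem, and dst_mem handles are placeholders assumed to exist, not values from the original file):

    // Create, execute, and destroy a primitive through the C API.
    mkldnn_primitive_t prim;
    mkldnn_primitive_create(&prim, pd); // pd: a previously created primitive_desc
    mkldnn_exec_arg_t args[] = {
        { MKLDNN_ARG_SRC, src_mem },
        { MKLDNN_ARG_DST, dst_mem },
    };
    mkldnn_primitive_execute(prim, stream, 2, args);
    mkldnn_primitive_destroy(prim);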
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
deleted file mode 100644
index 3b506d6d1f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
+++ /dev/null
@@ -1,76 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef PRIMITIVE_HPP
-#define PRIMITIVE_HPP
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "primitive_desc.hpp"
-#include "primitive_exec_types.hpp"
-
-/** \brief A pure virtual primitive class
- *
- * A primitive holds links to its inputs and outputs, though it does not track
- * their readiness at execution time.
- *
- * @remark @b Rationale.
- * Dependencies are essential throughout the whole MKL-DNN library, so it
- * makes sense to include them at the very lowest level. On the other hand,
- * tracking them should be the job of a corresponding entity, such as a
- * scheduler or a stream. The primitive itself should know nothing about the
- * environment it is running in.
- *
- * @note
- * To improve the user experience we should provide an API that allows
- * achieving the best (or good enough) performance when creating primitives
- * in natural order: i.e. from bottom to top for the forward pass and from top
- * to bottom for the backward pass. Please consider restriction [1] in Level 0.
- */
-struct mkldnn_primitive: public mkldnn::impl::c_compatible {
- mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
- : pd_(pd->clone()) {}
- virtual ~mkldnn_primitive() { delete pd_; }
-
- /** returns primitive's engine */
- mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
- /** returns primitive's descriptor */
- const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
- /** returns primitive's kind */
- mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
-
- /** executes primitive with execution context @p ctx */
- virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
- const = 0;
-
-protected:
- const mkldnn::impl::primitive_desc_t *pd_;
-
-private:
- mkldnn_primitive() = delete;
- mkldnn_primitive(const mkldnn_primitive &) = delete;
- mkldnn_primitive(mkldnn_primitive &&) = delete;
- mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
- mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
-};
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
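primitive.hpp declared the abstract base that every implementation derives from: it owns a clone of its descriptor and exposes a single pure-virtual execute(). A minimal sketch of a derived implementation, assuming the internal headers removed in this patch are available and that my_pd_t stands in for some concrete primitive_desc_t subclass (both names hypothetical):

    // Sketch only: a no-op primitive deriving from the removed base class.
    struct my_primitive_t : public mkldnn_primitive {
        my_primitive_t(const my_pd_t *apd) : mkldnn_primitive(apd) {}

        virtual mkldnn::impl::status_t execute(
                const mkldnn::impl::exec_ctx_t &ctx) const override {
            const void *src = ctx.input(MKLDNN_ARG_SRC);   // read-only argument
            void *dst = ctx.output(MKLDNN_ARG_DST);        // writable argument
            (void)src; (void)dst;                          // real work goes here
            return mkldnn::impl::status::success;
        }
    };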
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
deleted file mode 100644
index 9fd638842c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
+++ /dev/null
@@ -1,290 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_attr.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-namespace mkldnn {
-namespace impl {
-
-status_t scales_t::set(dim_t count, int mask, const float *scales) {
- cleanup();
-
- count_ = count;
- mask_ = mask;
-
- if (count_ == 1) {
- scales_ = scales_buf_;
- utils::array_set(scales_, scales[0], scales_buf_size);
- } else {
- scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
- if (scales_ == nullptr)
- return status::out_of_memory;
-
- for (dim_t c = 0; c < count_; ++c)
- scales_[c] = scales[c];
- }
-
- return status::success;
-}
-
-}
-}
-
-status_t post_ops_t::append_sum(float scale) {
- if (len_ == capacity)
- return out_of_memory;
-
- entry_[len_].kind = primitive_kind::sum;
- entry_[len_].sum.scale = scale;
-
- len_++;
-
- return success;
-}
-
-status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
- float beta) {
- using namespace mkldnn::impl::alg_kind;
- bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
- if (!known_alg)
- return invalid_arguments;
-
- if (len_ == capacity)
- return out_of_memory;
-
- entry_[len_].kind = primitive_kind::eltwise;
- entry_[len_].eltwise.scale = scale;
- entry_[len_].eltwise.alg = alg;
- entry_[len_].eltwise.alpha = alpha;
- entry_[len_].eltwise.beta = beta;
-
- len_++;
-
- return success;
-}
-
-status_t primitive_attr_t::set_scratchpad_mode(
- scratchpad_mode_t scratchpad_mode) {
- using namespace mkldnn::impl::scratchpad_mode;
-
- const bool ok = one_of(scratchpad_mode, library, user);
- if (!ok)
- return invalid_arguments;
-
- scratchpad_mode_ = scratchpad_mode;
- return success;
-}
-
-status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
- this->post_ops_ = post_ops;
- return success;
-}
-
-/* Public C API */
-
-status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
- if (attr == nullptr)
- return invalid_arguments;
-
- return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
- new mkldnn_primitive_attr);
-}
-
-status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
- const primitive_attr_t *existing_attr) {
- if (any_null(attr, existing_attr))
- return invalid_arguments;
-
- return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
- existing_attr->clone());
-}
-
-status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
- if (attr)
- delete attr;
-
- return success;
-}
-
-status_t mkldnn_primitive_attr_get_scratchpad_mode(
- const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
- if (any_null(attr, scratchpad_mode))
- return invalid_arguments;
-
- *scratchpad_mode = attr->scratchpad_mode_;
-
- return success;
-}
-
-status_t mkldnn_primitive_attr_set_scratchpad_mode(
- primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
- if (any_null(attr))
- return invalid_arguments;
-
- return attr->set_scratchpad_mode(scratchpad_mode);
-}
-
-status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
- dim_t *count, int *mask, const float **scales) {
- if (any_null(attr, count, mask, scales))
- return invalid_arguments;
-
- *count = attr->output_scales_.count_;
- *mask = attr->output_scales_.mask_;
- *scales = attr->output_scales_.scales_;
-
- return success;
-}
-
-status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
- dim_t count, int mask, const float *scales) {
- bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
- if (!ok)
- return invalid_arguments;
-
- return attr->output_scales_.set(count, mask, scales);
-}
-
-status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
- const post_ops_t **post_ops) {
- if (any_null(attr, post_ops))
- return invalid_arguments;
-
- *post_ops = &attr->post_ops_;
- return success;
-}
-
-status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
- const post_ops_t *post_ops) {
- if (any_null(attr, post_ops))
- return invalid_arguments;
-
- return attr->set_post_ops(*post_ops);
-}
-
-status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
- if (post_ops == nullptr)
- return invalid_arguments;
-
- return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
-}
-
-status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
- if (post_ops)
- delete post_ops;
-
- return success;
-}
-
-int mkldnn_post_ops_len(const post_ops_t *post_ops) {
- if (post_ops)
- return post_ops->len_;
-
- return 0;
-}
-
-primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
- int index) {
- bool ok = post_ops && 0 <= index && index < post_ops->len_;
- if (!ok)
- return primitive_kind::undefined;
-
- return post_ops->entry_[index].kind;
-}
-
-status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
- if (post_ops == nullptr)
- return invalid_arguments;
-
- return post_ops->append_sum(scale);
-}
-
-namespace {
-bool simple_get_params_check(const post_ops_t *post_ops, int index,
- primitive_kind_t kind) {
- bool ok = true
- && post_ops != nullptr
- && 0 <= index
- && index < post_ops->len_
- && post_ops->entry_[index].kind == kind;
- return ok;
-}
-}
-
-status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
- float *scale) {
- bool ok = true
- && simple_get_params_check(post_ops, index, primitive_kind::sum)
- && !any_null(scale);
- if (!ok)
- return invalid_arguments;
-
- *scale = post_ops->entry_[index].sum.scale;
- return success;
-}
-
-status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
- alg_kind_t kind, float alpha, float beta) {
- if (post_ops == nullptr)
- return invalid_arguments;
-
- return post_ops->append_eltwise(scale, kind, alpha, beta);
-}
-
-status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
- int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
- bool ok = true
- && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
- && !any_null(scale, alpha, beta);
- if (!ok)
- return invalid_arguments;
-
- const auto &e = post_ops->entry_[index].eltwise;
- *scale = e.scale;
- *alg = e.alg;
- *alpha = e.alpha;
- *beta = e.beta;
-
- return success;
-}
-
-status_t mkldnn_primitive_attr_set_rnn_data_qparams(
- primitive_attr_t *attr, const float scale, const float shift) {
- if (attr == nullptr)
- return invalid_arguments;
-
- return attr->rnn_data_qparams_.set(scale, shift);
-}
-
-status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
- primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
- bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
- if (!ok)
- return invalid_arguments;
-
- return attr->rnn_weights_qparams_.set(count, mask, scales);
-}
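primitive_attr.cpp backed the C API for attributes, output scales, and the small fixed-capacity post-ops list. A hedged sketch of how these removed entry points fit together, with error handling trimmed (the wrapper name make_relu_attr is hypothetical):

    // Hypothetical sketch: one output scale plus a ReLU post-op.
    #include "mkldnn.h"

    mkldnn_status_t make_relu_attr(mkldnn_primitive_attr_t *out) {
        mkldnn_primitive_attr_t attr;
        mkldnn_post_ops_t ops;
        float scale = 0.5f;

        mkldnn_primitive_attr_create(&attr);
        mkldnn_primitive_attr_set_output_scales(attr, 1, /*mask=*/0, &scale);

        mkldnn_post_ops_create(&ops);
        // eltwise_relu with alpha (negative slope) = 0 and unused beta = 0
        mkldnn_post_ops_append_eltwise(ops, 1.0f, mkldnn_eltwise_relu, 0.f, 0.f);
        mkldnn_primitive_attr_set_post_ops(attr, ops);  // attr copies the list
        mkldnn_post_ops_destroy(ops);

        *out = attr;
        return mkldnn_success;
    }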
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
deleted file mode 100644
index e2130c7ab1..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
+++ /dev/null
@@ -1,183 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef PRIMITIVE_ATTR_HPP
-#define PRIMITIVE_ATTR_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct rnn_data_qparams_t : public c_compatible {
- rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
- bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
-
- status_t set(float scale, float shift) {
- scale_ = scale;
- shift_ = shift;
- return status::success;
- }
-
- float scale_;
- float shift_;
-};
-
-struct scales_t: public c_compatible {
- scales_t(): count_(1), mask_(0), scales_(scales_buf_)
- { set(1.); }
-
- scales_t(const scales_t &rhs): scales_t()
- { set(rhs.count_, rhs.mask_, rhs.scales_); }
-
- ~scales_t() { cleanup(); }
-
- scales_t &operator=(const scales_t &rhs) {
- if (&rhs == this)
- return *this;
- status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
- assert(status == status::success);
- (void)status;
- return *this;
- }
-
- bool has_default_values() const {
- for (dim_t c = 0; c < count_; ++c) {
- if(scales_[c] != 1.) return false;
- }
- return true;
- }
-
- status_t set(dim_t count, int mask, const float *scales);
- status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
-
- dim_t count_;
- int mask_;
- float *scales_;
-
-private:
- enum { scales_buf_size = 16 };
- float scales_buf_[scales_buf_size];
-
- void cleanup() {
- if (scales_ != scales_buf_ && scales_ != nullptr)
- impl::free(scales_);
-
- count_ = 1;
- mask_ = 0;
- scales_ = scales_buf_;
- }
-};
-
-}
-}
-
-struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
- struct entry_t {
- struct eltwise_t {
- mkldnn::impl::alg_kind_t alg;
- float scale, alpha, beta;
- };
-
- mkldnn::impl::primitive_kind_t kind;
- union {
- struct { float scale; } sum;
- eltwise_t eltwise;
- };
-
- bool is_eltwise(bool require_scale_one = true) const {
- using namespace mkldnn::impl;
- return kind == primitive_kind::eltwise
- && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
- }
-
- bool is_relu(bool require_scale_one = true,
- bool require_nslope_zero = true) const {
- using namespace mkldnn::impl;
- return is_eltwise(require_scale_one)
- && eltwise.alg == alg_kind::eltwise_relu
- && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
- }
-
- bool is_sum(bool require_scale_one = true) const {
- using namespace mkldnn::impl;
- return kind == primitive_kind::sum
- && IMPLICATION(require_scale_one, sum.scale == 1.f);
- }
- };
-
- mkldnn_post_ops(): len_(0) {}
-
- mkldnn::impl::status_t append_sum(float scale);
- mkldnn::impl::status_t append_eltwise(float scale,
- mkldnn::impl::alg_kind_t alg, float alpha, float beta);
-
- int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
- int stop = -1) const {
- if (stop == -1) stop = len_;
- stop = mkldnn::impl::nstl::min(stop, len_);
- for (int idx = start; idx < stop; ++idx)
- if (entry_[idx].kind == kind) return idx;
- return -1;
- }
-
- bool has_default_values() const { return len_ == 0; }
-
- bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
- { return find(kind, index, index + 1) == index; }
-
- enum { capacity = 4 };
-
- int len_;
- entry_t entry_[capacity];
-};
-
-struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
- mkldnn_primitive_attr()
- : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
- {}
-
- mkldnn_primitive_attr *clone() const
- { return new mkldnn_primitive_attr(*this); }
-
- /** Returns true if the attributes have default values.
- *
- * @note The scratchpad_mode_ is not taken into account */
- bool has_default_values() const {
- return true
- && output_scales_.has_default_values()
- && post_ops_.has_default_values()
- && rnn_data_qparams_.has_default_values()
- && rnn_weights_qparams_.has_default_values();
- }
-
- mkldnn::impl::status_t set_scratchpad_mode(
- mkldnn::impl::scratchpad_mode_t scratchpad_mode);
- mkldnn::impl::status_t set_post_ops(
- const mkldnn::impl::post_ops_t &post_ops);
-
- mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
- mkldnn::impl::scales_t output_scales_;
- mkldnn::impl::post_ops_t post_ops_;
- mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
- mkldnn::impl::scales_t rnn_weights_qparams_;
-};
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
deleted file mode 100644
index 723c41e05a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
+++ /dev/null
@@ -1,78 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "primitive_desc.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-
-status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
- auto safe_ret_md = [&](const memory_desc_t *_) {
- if (_ == nullptr) return not_required;
- *(const memory_desc_t **)result = _;
- return success;
- };
-
- switch (what) {
- case query::engine: *(engine_t**)result = engine(); break;
- case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
-
- case query::scratchpad_engine:
- *(engine_t**)result = scratchpad_engine(); break;
-
- case query::memory_consumption_s64:
- *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
-
- case query::op_d:
- if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
- *(const_c_op_desc_t *)result
- = static_cast<const_c_op_desc_t>(op_desc()); break;
-
- case query::src_md: return safe_ret_md(src_md(idx));
- case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
- case query::dst_md: return safe_ret_md(dst_md(idx));
- case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
- case query::weights_md: return safe_ret_md(weights_md(idx));
- case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
- case query::workspace_md:
- if (idx != 0) return status::invalid_arguments;
- return safe_ret_md(workspace_md(idx));
- case query::scratchpad_md:
- if (idx != 0) return status::invalid_arguments;
- return safe_ret_md(scratchpad_md(idx));
-
- case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
- case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
-
- case query::impl_info_str: *(const char **)result = name(); break;
-
- default: return unimplemented;
- }
- return success;
-}
-
-status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
- const primitive_attr_t **attr) {
- if (utils::any_null(primitive_desc, attr))
- return invalid_arguments;
-
- *attr = primitive_desc->attr();
- return success;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
deleted file mode 100644
index 536dcfa1d0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
+++ /dev/null
@@ -1,174 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef PRIMITIVE_DESC_HPP
-#define PRIMITIVE_DESC_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "primitive_attr.hpp"
-#include "verbose.hpp"
-
-struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
- using md_t = mkldnn::impl::memory_desc_t;
-
- mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
- const mkldnn::impl::primitive_attr_t *attr,
- mkldnn::impl::primitive_kind_t kind)
- : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
-
- mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
- mkldnn::impl::primitive_kind_t kind)
- : engine_(engine), kind_(kind) { info_[0] = '\0'; }
-
- virtual mkldnn_primitive_desc *clone() const = 0;
- virtual ~mkldnn_primitive_desc() {}
-
- const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
- mkldnn::impl::engine_t *engine() const { return engine_; }
- mkldnn::impl::primitive_kind_t kind() const { return kind_; }
-
- virtual void init_info() {}
- const char *info() const { return info_; }
-
- mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
- { return scratchpad_registry_; }
- const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
- { return scratchpad_registry_; }
- virtual mkldnn::impl::engine_t *scratchpad_engine() const
- { return engine_; }
-
- virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
-
- enum class arg_usage_t { unused, input, output };
- virtual arg_usage_t arg_usage(
- mkldnn::impl::primitive_arg_index_t arg) const {
- using mkldnn::impl::types::is_zero_md;
- if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
- return arg_usage_t::output;
- return arg_usage_t::unused;
- }
-
-# define DECLARE_MD_STUB(stub) \
- virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
- { return nullptr; }
-
- DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
- DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
- DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
- DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
- DECLARE_MD_STUB(workspace_md);
-# undef DECLARE_MD_STUB
-
- const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
- return idx == 0 ? &scratchpad_md_ : nullptr;
- }
-
- virtual void init_scratchpad_md() {
- auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
- mkldnn::impl::dims_t dims = { size };
- mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
- mkldnn::impl::data_type::u8, mkldnn_x);
- }
-
- /** returns the scratchpad size for the given scratchpad mode. */
- mkldnn::impl::dim_t scratchpad_size(
- mkldnn::impl::scratchpad_mode_t mode) const {
- if (mode != attr_.scratchpad_mode_) return 0;
- return scratchpad_registry().size();
- }
-
- virtual int n_inputs() const { return 0; }
- virtual int n_outputs() const { return 0; }
-
- virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
- void *result) const;
-
- virtual mkldnn::impl::status_t create_primitive(
- mkldnn::impl::primitive_t **primitive) const = 0;
-
- virtual const char *name() const { return "mkldnn_primitive_desc"; }
-
- /* static magic */
-
- template<typename pd_t>
- static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
- const mkldnn::impl::op_desc_t *adesc,
- const mkldnn::impl::primitive_attr_t *attr,
- mkldnn::impl::engine_t *engine,
- const mkldnn::impl::primitive_desc_t *hint_fwd) {
- using namespace mkldnn::impl;
- using namespace mkldnn::impl::status;
- using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
- if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
- assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
- auto hint =
- reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
- auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
- if (_pd == nullptr) return out_of_memory;
- if (_pd->init() != success) { delete _pd; return unimplemented; }
- _pd->init_info();
- _pd->init_scratchpad_md();
- *pd = _pd;
- return success;
- }
-
-protected:
- mkldnn::impl::engine_t *engine_;
- mkldnn::impl::primitive_attr_t attr_;
- mkldnn::impl::primitive_kind_t kind_;
-
- mkldnn::impl::memory_desc_t scratchpad_md_;
-
- char info_[MKLDNN_VERBOSE_BUF_LEN];
-
- mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
-
-protected:
- /** compares the workspace between fwd_pd and this (makes sense for bwd_pd).
- * Expectation: this pd has already set its workspace, and that workspace
- * should exactly match the one from fwd_pd */
- bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
- using namespace mkldnn::impl;
- if (!workspace_md()) return true; // the impl lives fine w/o workspace
- return fwd_pd && fwd_pd->workspace_md()
- && *fwd_pd->workspace_md() == *workspace_md();
- }
-};
-
-#define DECLARE_COMMON_PD_t(impl_name, ...) \
- virtual pd_t *clone() const override { return new pd_t(*this); } \
- virtual status_t create_primitive(primitive_t **p) const override { \
- double ms = get_msec(); \
- auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
- ms = get_msec() - ms; \
- if (mkldnn_verbose()->level >= 2) { \
- printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
- fflush(0); \
- } \
- return ret; \
- } \
- virtual const char *name() const override { return impl_name; }
-#define DECLARE_COMMON_PD_T(impl_name, ...) \
- DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
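The DECLARE_COMMON_PD_T macro removed above is what concrete implementations used to stamp out clone(), create_primitive(), and name() on their nested pd_t. A rough sketch of that pattern, with hypothetical class names and assuming the usual internal base classes (e.g. convolution_fwd_pd_t) are in scope:

    // Sketch: how an implementation's pd_t typically used the removed macro.
    struct my_conv_fwd_t : public mkldnn_primitive {
        struct pd_t : public mkldnn::impl::convolution_fwd_pd_t {
            using mkldnn::impl::convolution_fwd_pd_t::convolution_fwd_pd_t;
            DECLARE_COMMON_PD_T("ref:my_conv_fwd", my_conv_fwd_t);
            mkldnn::impl::status_t init() { return mkldnn::impl::status::success; }
        };
        // constructor and execute() omitted
    };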
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
deleted file mode 100644
index 43e5a31ef3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "memory.hpp"
-#include "primitive.hpp"
-#include "primitive_exec_types.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
- const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
- using namespace status;
-
- if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
-
- int n_inputs = 0;
- int n_outputs = 0;
-
- for (int i = 0; i < nargs; ++i) {
- primitive_arg_index_t arg = c_args[i].arg;
- auto *mem = c_args[i].memory;
-
- switch (pd->arg_usage(arg)) {
- case primitive_desc_t::arg_usage_t::input:
- if (args.count(arg) != 0) return invalid_arguments;
- args[arg] = {mem, true};
- n_inputs++;
- break;
- case primitive_desc_t::arg_usage_t::output:
- if (args.count(arg) != 0) return invalid_arguments;
- args[arg] = {mem, false};
- n_outputs++;
- break;
- case primitive_desc_t::arg_usage_t::unused:
- break;
- }
- }
-
- bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
-
- if (n_inputs != pd->n_inputs()) return invalid_arguments;
- if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
- return invalid_arguments;
-
- return success;
-}
-
-const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
- if (args_.count(arg) != 1) return nullptr;
- const auto ma = args_.at(arg);
- assert(ma.is_const);
- void *ptr;
- status_t status = ma.mem->get_data_handle(&ptr);
- assert(status == status::success); MAYBE_UNUSED(status);
- return ptr;
-}
-
-void *exec_ctx_t::output(primitive_arg_index_t arg) const {
- if (args_.count(arg) != 1) return nullptr;
- const auto ma = args_.at(arg);
- assert(!ma.is_const);
- void *ptr;
- status_t status = ma.mem->get_data_handle(&ptr);
- assert(status == status::success); MAYBE_UNUSED(status);
- return ptr;
-}
-
-const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
- assert(args_.count(arg) == 1);
- const auto ma = args_.at(arg);
- assert(!ma.is_const);
- return ma.mem;
-}
-
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
deleted file mode 100644
index 0645891da7..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
+++ /dev/null
@@ -1,68 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef PRIMITIVE_EXEC_TYPES_HPP
-#define PRIMITIVE_EXEC_TYPES_HPP
-
-#include <unordered_map>
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "memory.hpp"
-#include "primitive_desc.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct memory_arg_t {
- memory_t *mem;
- bool is_const;
-};
-
-using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
-
-status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
- const mkldnn_exec_arg_t *c_args, exec_args_t &args);
-
-/** Primitive execution context (helps pass the stream, memories, and events). */
-struct exec_ctx_t {
- exec_ctx_t(const exec_ctx_t &) = default;
- exec_ctx_t(exec_ctx_t &&) = default;
-
- exec_ctx_t(stream_t *stream): stream_(stream) {}
- exec_ctx_t(stream_t *stream, exec_args_t &&args)
- : stream_(stream)
- , args_(std::move(args)) {}
-
- stream_t *stream() const { return stream_; }
- const exec_args_t &args() const { return args_; }
-
- /* tentative solution... TODO: replace with functions returning memory_t */
- const void *input(primitive_arg_index_t arg) const;
- void *output(primitive_arg_index_t arg) const;
-
- const memory_t *memory(primitive_arg_index_t arg) const;
-
-private:
- stream_t *stream_;
- exec_args_t args_;
-};
-
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
deleted file mode 100644
index 5a1cd7d379..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
+++ /dev/null
@@ -1,89 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-#include "primitive_iterator.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-
-status_t mkldnn_primitive_desc_iterator_create(
- primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
- const primitive_attr_t *attr, engine_t *engine,
- const primitive_desc_t *hint_fwd_pd) {
- const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
-
- auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
- if (it == nullptr) return out_of_memory;
-
- ++(*it);
- if (*it == it->end()) {
- delete it;
- return unimplemented;
- }
-
- *iterator = it;
- return success;
-}
-
-status_t mkldnn_primitive_desc_iterator_next(
- primitive_desc_iterator_t *iterator) {
- if (iterator == nullptr) return invalid_arguments;
- ++(*iterator);
- return *iterator == iterator->end() ? iterator_ends : success;
-}
-
-primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
- const primitive_desc_iterator_t *iterator) {
- if (iterator == nullptr) return nullptr;
- return *(*iterator);
-}
-
-status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
- const primitive_desc_t *existing_primitive_desc) {
- if (utils::any_null(primitive_desc, existing_primitive_desc))
- return invalid_arguments;
- return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
- existing_primitive_desc->clone());
-}
-
-status_t mkldnn_primitive_desc_iterator_destroy(
- primitive_desc_iterator_t *iterator) {
- if (iterator != nullptr)
- delete iterator;
- return success;
-}
-
-status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
- const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
- engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
- const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
-
- mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
- ++it;
- if (it == it.end()) return unimplemented;
-
- return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
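The iterator removed here walks the engine's implementation list, keeping only the entries whose create function succeeds for the given operation descriptor. A hedged sketch of enumerating those candidates through the public C API (passing a null attribute, which the iterator replaces with defaults):

    // Hypothetical sketch: list all implementations accepting op_desc.
    #include <cstdio>
    #include "mkldnn.h"

    void list_impls(const_mkldnn_op_desc_t op_desc, mkldnn_engine_t engine) {
        mkldnn_primitive_desc_iterator_t it;
        if (mkldnn_primitive_desc_iterator_create(&it, op_desc, nullptr, engine,
                    nullptr) != mkldnn_success)
            return;  // nothing implements this descriptor

        do {
            mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(it);
            const char *name = nullptr;
            mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0, &name);
            printf("candidate implementation: %s\n", name);
            mkldnn_primitive_desc_destroy(pd);
        } while (mkldnn_primitive_desc_iterator_next(it) == mkldnn_success);

        mkldnn_primitive_desc_iterator_destroy(it);
    }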
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
deleted file mode 100644
index 4e88ab3aa5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
+++ /dev/null
@@ -1,79 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-#ifndef PRIMITIVE_ITERATOR_HPP
-#define PRIMITIVE_ITERATOR_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-
-struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
- using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
-
- mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
- const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
- : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
- , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
- , impl_list_(engine_->get_implementation_list()), last_idx_(0)
- {
- while (impl_list_[last_idx_] != nullptr) ++last_idx_;
- }
- ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
-
- bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
- { return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
- bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
- { return !operator==(rhs); }
-
- mkldnn::impl::primitive_desc_iterator_t end() const
- { return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
-
- mkldnn::impl::primitive_desc_iterator_t &operator++() {
- if (pd_) { delete pd_; pd_ = nullptr; }
- while (++idx_ != last_idx_) {
- auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
- hint_fwd_pd_);
- if (s == mkldnn::impl::status::success) break;
- }
- return *this;
- }
-
- mkldnn::impl::primitive_desc_t *operator*() const {
- if (*this == end() || pd_ == nullptr) return nullptr;
- return pd_->clone();
- }
-
-protected:
- int idx_;
- mkldnn::impl::engine_t *engine_;
- mkldnn::impl::primitive_desc_t *pd_;
- const mkldnn::impl::op_desc_t *op_desc_;
- const mkldnn::impl::primitive_attr_t attr_;
- const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
- const pd_create_f *impl_list_;
- int last_idx_;
-
-private:
- mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
- : idx_(last_idx), engine_(engine), pd_(nullptr)
- , op_desc_(nullptr), hint_fwd_pd_(nullptr)
- , impl_list_(nullptr), last_idx_(last_idx) {}
-};
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp
deleted file mode 100644
index 835cd73581..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/query.cpp
+++ /dev/null
@@ -1,59 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "primitive_desc.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-
-status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
- query_t what, int index, void *result) {
- if (any_null(primitive_desc, result))
- return invalid_arguments;
-
- return primitive_desc->query(what, index, result);
-}
-
-const memory_desc_t *mkldnn_primitive_desc_query_md(
- const primitive_desc_t *primitive_desc, query_t what, int index) {
- const memory_desc_t *res_md = nullptr;
- bool args_ok = true
- && primitive_desc != nullptr
- && (what & query::some_md) == query::some_md
- && what != query::some_md
- && mkldnn_primitive_desc_query(primitive_desc,
- what, index, &res_md) == success;
- return args_ok ? res_md : nullptr;
-}
-
-int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
- query_t what, int index) {
- int res_s32;
- bool args_ok = primitive_desc != nullptr
- && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
- && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
- == success;
- return args_ok ? res_s32 : 0;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
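query.cpp only added argument validation on top of primitive_desc_t::query(); the _md and _s32 helpers return nullptr or 0 when a query does not apply. A short sketch of pulling typed results out of a descriptor (function name inspect is hypothetical):

    // Hypothetical sketch using the removed typed query wrappers.
    #include <cstdio>
    #include "mkldnn.h"

    void inspect(const_mkldnn_primitive_desc_t pd) {
        const mkldnn_memory_desc_t *src_md =
                mkldnn_primitive_desc_query_md(pd, mkldnn_query_src_md, 0);
        int n_inputs = mkldnn_primitive_desc_query_s32(
                pd, mkldnn_query_num_of_inputs_s32, 0);
        printf("inputs: %d, src ndims: %d\n",
                n_inputs, src_md ? src_md->ndims : -1);
    }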
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
deleted file mode 100644
index d11f1a0361..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
+++ /dev/null
@@ -1,68 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "reorder_pd.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-
-status_t mkldnn_reorder_primitive_desc_create(
- primitive_desc_t **reorder_pd,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md,
- const primitive_attr_t *attr) {
- if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
- return invalid_arguments;
-
- auto s_ek = src_engine->kind();
- auto d_ek = dst_engine->kind();
- if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
- return invalid_arguments;
-
- auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
- auto s_mdw = memory_desc_wrapper(*src_md);
- auto d_mdw = memory_desc_wrapper(*dst_md);
-
- if (!s_mdw.consistent_with(d_mdw))
- return invalid_arguments;
-
- auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
-
- const primitive_attr_t dummy_attr;
- if (attr == NULL)
- attr = &dummy_attr;
-
- for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
- if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
- == success) {
- (*r_pd)->init_info();
- (*r_pd)->init_scratchpad_md();
- return success;
- }
- }
- return unimplemented;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
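reorder.cpp tried each reorder implementation registered on the (CPU) engine until one accepted the source/destination descriptor pair. A minimal sketch of the removed entry point, assuming both descriptors live on the same CPU engine (wrapper name make_reorder_pd is hypothetical):

    // Hypothetical sketch: create a reorder primitive descriptor.
    #include "mkldnn.h"

    mkldnn_status_t make_reorder_pd(mkldnn_engine_t engine,
            const mkldnn_memory_desc_t *src_md,
            const mkldnn_memory_desc_t *dst_md,
            mkldnn_primitive_desc_t *reorder_pd) {
        // A null attribute is allowed; the implementation substitutes defaults.
        return mkldnn_reorder_primitive_desc_create(reorder_pd,
                engine, src_md, engine, dst_md, nullptr);
    }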
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
deleted file mode 100644
index 963cb0f58a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
+++ /dev/null
@@ -1,85 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef REORDER_PD_HPP
-#define REORDER_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "primitive_attr.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct reorder_pd_t: public primitive_desc_t {
- reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md)
- : primitive_desc_t(engine, attr, primitive_kind::reorder)
- , src_engine_(src_engine)
- , dst_engine_(dst_engine)
- , scratchpad_engine_(nullptr)
- , src_md_(*src_md)
- , dst_md_(*dst_md)
- {}
-
- virtual const op_desc_t *op_desc() const override { return nullptr; }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_FROM)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_TO)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &src_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override { return 1; }
-
- float alpha() const { return attr()->output_scales_.scales_[0]; }
- float beta() const {
- const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
- return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
- }
- virtual mkldnn::impl::engine_t *scratchpad_engine() const override
- { return scratchpad_engine_; }
-
-protected:
- engine_t *src_engine_;
- engine_t *dst_engine_;
- engine_t *scratchpad_engine_;
-
- memory_desc_t src_md_;
- memory_desc_t dst_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
deleted file mode 100644
index 36967431a6..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
+++ /dev/null
@@ -1,400 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "cpu/gemm/os_blas.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::types;
-using namespace mkldnn::impl::utils;
-
-namespace {
-memory_desc_t copy_maybe_null(const memory_desc_t *md) {
- return md ? *md : zero_md();
-}
-
-rnn_desc_t zero_rnn_desc() {
- auto rd = rnn_desc_t();
- rd.src_layer_desc = zero_md();
- rd.src_iter_desc = zero_md();
- rd.weights_layer_desc = zero_md();
- rd.weights_iter_desc = zero_md();
- rd.bias_desc = zero_md();
- rd.dst_layer_desc = zero_md();
- rd.dst_iter_desc = zero_md();
- rd.diff_src_layer_desc = zero_md();
- rd.diff_src_iter_desc = zero_md();
- rd.diff_weights_layer_desc = zero_md();
- rd.diff_weights_iter_desc = zero_md();
- rd.diff_bias_desc = zero_md();
- rd.diff_dst_layer_desc = zero_md();
- rd.diff_dst_iter_desc = zero_md();
- return rd;
-}
-}
-
-/* Public C Api */
-
-status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
- mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
- unsigned int flags, float alpha, float clipping) {
- using namespace mkldnn::impl::alg_kind;
-
- bool args_ok = true
- && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
- gru_linear_before_reset)
- && IMPLICATION(cell_kind == vanilla_rnn,
- one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
- if (!args_ok)
- return invalid_arguments;
-
- auto rcd = mkldnn_rnn_cell_desc_t();
-
- rcd.cell_kind = cell_kind;
- rcd.activation_kind = act_f;
- rcd.flags = flags;
- rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
- rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
-
- *rnn_cell_desc = rcd;
-
- return success;
-}
-
-int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
- switch (rnn_cell_desc->cell_kind) {
- case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
- case mkldnn::impl::alg_kind::vanilla_gru: return 3;
- case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
- case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
- default: assert(!"unknown cell kind"); return 0;
- }
- return 0;
-}
-
-int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
- switch (rnn_cell_desc->cell_kind) {
- case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
- case mkldnn::impl::alg_kind::vanilla_gru: return 1;
- case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
- case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
- default: assert(!"unknown cell kind"); return 0;
- }
- return 0;
-}
-
-status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
- prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
- const memory_desc_t *src_iter_desc,
- const memory_desc_t *weights_layer_desc,
- const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_layer_desc,
- const memory_desc_t *dst_iter_desc) {
- using namespace data_type;
- data_type_t src_layer_dt = src_layer_desc->data_type;
- data_type_t dst_layer_dt = dst_layer_desc->data_type;
- data_type_t weights_iter_dt = weights_iter_desc->data_type;
- data_type_t weights_layer_dt = weights_layer_desc->data_type;
-
- bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
- weights_layer_dt)
- && IMPLICATION(!is_zero_md(src_iter_desc),
- src_iter_desc->data_type == f32)
- && IMPLICATION(!is_zero_md(dst_iter_desc),
- dst_iter_desc->data_type == f32)
- && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
-
-#if USE_MKL_PACKED_GEMM
- bool is_u8u8u8 = src_layer_dt == u8
- && IMPLICATION(!is_zero_md(src_iter_desc),
- src_iter_desc->data_type == u8)
- && IMPLICATION(!is_zero_md(dst_iter_desc),
- dst_iter_desc->data_type == u8)
- && one_of(dst_layer_dt, u8, f32)
- && everyone_is(s8, weights_iter_dt, weights_layer_dt)
- && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
-
- bool is_f32u8f32 = src_layer_dt == u8
- && IMPLICATION(!is_zero_md(src_iter_desc),
- src_iter_desc->data_type == f32)
- && IMPLICATION(!is_zero_md(dst_iter_desc),
- dst_iter_desc->data_type == f32)
- && one_of(dst_layer_dt, u8, f32)
- && everyone_is(s8, weights_iter_dt, weights_layer_dt)
- && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
-
- bool is_inference = prop_kind == prop_kind::forward_inference;
- bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
-
- return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
- ? success
- : unimplemented;
-#else
- return is_f32 ? success : unimplemented;
-#endif
-}
-
-status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
- rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
- int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
- const memory_desc_t *src_iter_desc,
- const memory_desc_t *weights_layer_desc,
- const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_layer_desc,
- const memory_desc_t *dst_iter_desc) {
- bool args_ok;
-
- // * algorithm specific
- args_ok = true
- && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
- DIC == SIC);
- if (!args_ok) return invalid_arguments;
- int extra_bias =
- rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
-
- // * on num layers
- args_ok = true
- && L == weights_layer_desc->dims[0]
- && L == weights_iter_desc->dims[0]
- && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
- && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
- && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
- if (!args_ok) return invalid_arguments;
-
- // * on num directions
- args_ok = true
- && D == weights_layer_desc->dims[1]
- && D == weights_iter_desc->dims[1]
- && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
- && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
- && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
- if (!args_ok) return invalid_arguments;
-
- // * on num iterations
- args_ok = true
- && T == src_layer_desc->dims[0]
- && T == dst_layer_desc->dims[0];
- if (!args_ok) return invalid_arguments;
-
- // * on mb
- args_ok = true
- && N == src_layer_desc->dims[1]
- && N == dst_layer_desc->dims[1]
- && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
- && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
- if (!args_ok) return invalid_arguments;
-
- // * on num gates
- args_ok = true
- && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
- && G == weights_layer_desc->dims[3]
- && G == weights_iter_desc->dims[3]
- && IMPLICATION(!is_zero_md(bias_desc),
- G + extra_bias == bias_desc->dims[2]);
- if (!args_ok) return invalid_arguments;
-
- // * on num states
- args_ok = true
- && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
- && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
- && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
- if (!args_ok) return invalid_arguments;
-
- // * on slc
- args_ok = true
- && SLC == weights_layer_desc->dims[2]
- && SLC == src_layer_desc->dims[2];
- if (!args_ok) return invalid_arguments;
-
- // * on sic
- args_ok = true
- && SIC == weights_iter_desc->dims[2]
- && IMPLICATION(!is_zero_md(src_iter_desc),
- SIC == src_iter_desc->dims[4]);
- if (!args_ok) return invalid_arguments;
-
- // * on dlc
- int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
- args_ok = true
- && DLC == dlc_multiplier * DIC
- && DLC == dst_layer_desc->dims[2];
- if (!args_ok) return invalid_arguments;
-
- // * on dic
- args_ok = true
- && DIC == weights_layer_desc->dims[4]
- && DIC == weights_iter_desc->dims[4]
- && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
- && IMPLICATION(!is_zero_md(dst_iter_desc),
- DIC == dst_iter_desc->dims[4]);
- if (!args_ok) return invalid_arguments;
-
- // * unrolling/fusion conditions
- args_ok = true
- && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
- && IMPLICATION(T > 1, SIC == DIC);
- if (!args_ok) return invalid_arguments;
-
- return success;
-}
-
-status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
- prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
- const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
- const memory_desc_t *src_iter_desc,
- const memory_desc_t *weights_layer_desc,
- const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_layer_desc,
- const memory_desc_t *dst_iter_desc) {
- bool args_ok = true && rnn_cell_desc != nullptr
- && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
- dst_layer_desc);
- if (!args_ok) return invalid_arguments;
-
- // check dimension consistency
- int L = weights_layer_desc->dims[0];
- int T = src_layer_desc->dims[0];
- int N = src_layer_desc->dims[1];
- const int D = one_of(direction, mkldnn_unidirectional_left2right,
- mkldnn_unidirectional_right2left) ?
- 1 :
- 2;
- int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
- int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
- int SLC = src_layer_desc->dims[2];
- int SIC = weights_iter_desc->dims[2];
- int DLC = dst_layer_desc->dims[2];
- int DIC = weights_layer_desc->dims[4];
-
- CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
- G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
- weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
- dst_iter_desc));
-
- CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
- src_layer_desc, src_iter_desc, weights_layer_desc,
- weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
-
- // Create the descriptor
- mkldnn_rnn_desc_t rd = zero_rnn_desc();
-
- rd.primitive_kind = primitive_kind::rnn;
- rd.prop_kind = prop_kind;
- rd.cell_desc = *rnn_cell_desc;
- rd.direction = direction;
- rd.src_layer_desc = copy_maybe_null(src_layer_desc);
- rd.src_iter_desc = copy_maybe_null(src_iter_desc);
- rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
- rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
- rd.bias_desc = copy_maybe_null(bias_desc);
- rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
- rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
-
- *rnn_desc = rd;
-
- return success;
-}
-
-status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
- prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
- const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
- const memory_desc_t *src_iter_desc,
- const memory_desc_t *weights_layer_desc,
- const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
- const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
- const memory_desc_t *diff_src_layer_desc,
- const memory_desc_t *diff_src_iter_desc,
- const memory_desc_t *diff_weights_layer_desc,
- const memory_desc_t *diff_weights_iter_desc,
- const memory_desc_t *diff_bias_desc,
- const memory_desc_t *diff_dst_layer_desc,
- const memory_desc_t *diff_dst_iter_desc) {
- bool args_ok = true
- && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
- dst_layer_desc, diff_src_layer_desc,
- diff_weights_layer_desc, diff_weights_iter_desc,
- diff_dst_layer_desc);
- if (!args_ok)
- return invalid_arguments;
-
- auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
- return is_zero_md(a_md) == is_zero_md(b_md);
- };
-
- args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
- && xnor_md(dst_iter_desc, diff_dst_iter_desc)
- && xnor_md(src_iter_desc, diff_src_iter_desc);
- if (!args_ok)
- return invalid_arguments;
-
- //check dimensions consistency
- int L = weights_layer_desc->dims[0];
- int T = src_layer_desc->dims[0];
- int N = src_layer_desc->dims[1];
- const int D = one_of(direction, mkldnn_unidirectional_left2right,
- mkldnn_unidirectional_right2left) ?
- 1 :
- 2;
- int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
- int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
- int SLC = src_layer_desc->dims[2];
- int SIC = weights_iter_desc->dims[2];
- int DLC = dst_layer_desc->dims[2];
- int DIC = weights_layer_desc->dims[4];
-
- status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
- G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
- weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
- dst_iter_desc);
- if (st != success) return st;
-
- st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
- G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
- diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
- diff_dst_layer_desc, diff_dst_iter_desc);
- if (st != success) return st;
-
- mkldnn_rnn_desc_t rd = zero_rnn_desc();
-
- rd.primitive_kind = primitive_kind::rnn;
- rd.prop_kind = prop_kind;
- rd.cell_desc = *rnn_cell_desc;
- rd.direction = direction;
-
- rd.src_layer_desc = copy_maybe_null(src_layer_desc);
- rd.src_iter_desc = copy_maybe_null(src_iter_desc);
- rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
- rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
- rd.bias_desc = copy_maybe_null(bias_desc);
- rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
- rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
- rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
- rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
- rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
- rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
- rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
- rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
- rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
-
- *rnn_desc = rd;
-
- return success;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
deleted file mode 100644
index 1ee2ba1114..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
+++ /dev/null
@@ -1,280 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef RNN_PD_HPP
-#define RNN_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct rnn_fwd_pd_t;
-
-struct rnn_pd_t : public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::rnn;
-
- rnn_pd_t(engine_t *engine,
- const rnn_desc_t *adesc,
- const primitive_attr_t *attr,
- const rnn_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , src_layer_md_(desc_.src_layer_desc)
- , src_iter_md_(desc_.src_iter_desc)
- , weights_layer_md_(desc_.weights_layer_desc)
- , weights_iter_md_(desc_.weights_iter_desc)
- , bias_md_(desc_.bias_desc)
- , dst_layer_md_(desc_.dst_layer_desc)
- , dst_iter_md_(desc_.dst_iter_desc)
- , ws_md_()
- {}
-
- const rnn_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override {
- if (index == 0) return &src_layer_md_;
- if (index == 1 && with_src_iter()) return &src_iter_md_;
- return nullptr;
- }
- virtual const memory_desc_t *weights_md(int index = 0) const override {
- if (index == 0) return &weights_layer_md_;
- if (index == 1) return &weights_iter_md_;
- if (index == 2 && with_bias()) return &bias_md_;
- return nullptr;
- }
- virtual const memory_desc_t *dst_md(int index = 0) const override {
- if (index == 0) return &dst_layer_md_;
- if (index == 1 && with_dst_iter()) return &dst_iter_md_;
- return nullptr;
- }
- virtual const memory_desc_t *workspace_md(int index = 0) const override
- { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
-
- /* common pooling aux functions */
-
- bool is_training() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::backward);
- }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
- dim_t T() const { return desc_.src_layer_desc.dims[0]; }
- dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
-
- dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
- dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
-
- dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
-
- dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
- dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
- dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
-
- dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
-
- bool with_bias() const
- { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
-
- bool with_src_iter() const
- { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
-
- bool with_dst_iter() const
- { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
-
- mkldnn::impl::alg_kind_t cell_kind() const
- { return desc_.cell_desc.cell_kind; }
- mkldnn::impl::alg_kind_t activation_kind() const
- { return desc_.cell_desc.activation_kind; }
-
- bool is_lbr() const
- { return cell_kind() == mkldnn_gru_linear_before_reset; }
-
- mkldnn_rnn_direction_t direction() const { return desc_.direction; }
-
-protected:
- rnn_desc_t desc_;
- const rnn_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t src_layer_md_;
- memory_desc_t src_iter_md_;
- memory_desc_t weights_layer_md_;
- memory_desc_t weights_iter_md_;
- memory_desc_t bias_md_;
- memory_desc_t dst_layer_md_;
- memory_desc_t dst_iter_md_;
-
- memory_desc_t ws_md_;
-};
-
-struct rnn_fwd_pd_t: public rnn_pd_t {
- typedef rnn_fwd_pd_t base_class;
- typedef rnn_fwd_pd_t hint_class;
-
- rnn_fwd_pd_t(engine_t *engine,
- const rnn_desc_t *adesc,
- const primitive_attr_t *attr,
- const rnn_fwd_pd_t *hint_fwd_pd)
- : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC_LAYER)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
- return arg_usage_t::input;
-
- if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
- MKLDNN_ARG_WEIGHTS_ITER))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_BIAS && with_bias())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST_LAYER)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && is_training())
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual int n_inputs() const override
- { return 3 + with_bias() + with_src_iter(); }
- virtual int n_outputs() const override
- { return 1 + with_dst_iter() + is_training(); }
-};
-
-struct rnn_bwd_pd_t : public rnn_pd_t {
- typedef rnn_bwd_pd_t base_class;
- typedef rnn_fwd_pd_t hint_class;
-
- rnn_bwd_pd_t(engine_t *engine,
- const rnn_desc_t *adesc,
- const primitive_attr_t *attr,
- const rnn_fwd_pd_t *hint_fwd_pd)
- : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_src_layer_md_(desc_.diff_src_layer_desc)
- , diff_src_iter_md_(desc_.diff_src_iter_desc)
- , diff_weights_layer_md_(desc_.diff_weights_layer_desc)
- , diff_weights_iter_md_(desc_.diff_weights_iter_desc)
- , diff_bias_md_(desc_.diff_bias_desc)
- , diff_dst_layer_md_(desc_.diff_dst_layer_desc)
- , diff_dst_iter_md_(desc_.diff_dst_iter_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
- MKLDNN_ARG_DIFF_DST_LAYER))
- return arg_usage_t::input;
-
- if (with_src_iter()) {
- if (arg == MKLDNN_ARG_SRC_ITER)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
- return arg_usage_t::output;
- }
-
- if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
- MKLDNN_ARG_WEIGHTS_ITER))
- return arg_usage_t::input;
-
- if (with_bias()) {
- if (arg == MKLDNN_ARG_BIAS)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_BIAS)
- return arg_usage_t::output;
- }
-
- if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
- && with_dst_iter())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_WORKSPACE)
- return arg_usage_t::input;
-
- if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
- MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
- MKLDNN_ARG_DIFF_WEIGHTS_ITER))
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override {
- if (index == 0) return &diff_src_layer_md_;
- if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
- return nullptr;
- }
- virtual const memory_desc_t *diff_weights_md(
- int index = 0) const override {
- if (index == 0) return &diff_weights_layer_md_;
- if (index == 1) return &diff_weights_iter_md_;
- if (index == 2 && with_bias()) return &diff_bias_md_;
- return nullptr;
- }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
- if (index == 0) return &diff_dst_layer_md_;
- if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
- return nullptr;
- }
-
- virtual int n_inputs() const override
- { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
- virtual int n_outputs() const override
- { return 3 + with_src_iter() + with_bias(); }
-
-protected:
- memory_desc_t diff_src_layer_md_;
- memory_desc_t diff_src_iter_md_;
- memory_desc_t diff_weights_layer_md_;
- memory_desc_t diff_weights_iter_md_;
- memory_desc_t diff_bias_md_;
- memory_desc_t diff_dst_layer_md_;
- memory_desc_t diff_dst_iter_md_;
-};
-
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
deleted file mode 100644
index 6bc14fc72a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
+++ /dev/null
@@ -1,112 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "scratchpad.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-/* Allocating memory buffers on a page boundary to reduce TLB/page misses */
-const size_t page_size = 2097152;
-
-/*
- Implementation of the scratchpad_t interface that is compatible with
- a concurrent execution
-*/
-struct concurent_scratchpad_t : public scratchpad_t {
- concurent_scratchpad_t(size_t size) {
- size_ = size;
- scratchpad_ = (char *) malloc(size, page_size);
- assert(scratchpad_ != nullptr);
- }
-
- ~concurent_scratchpad_t() {
- free(scratchpad_);
- }
-
- virtual char *get() const {
- return scratchpad_;
- }
-
-private:
- char *scratchpad_;
- size_t size_;
-};
-
-/*
- Implementation of the scratchpad_t interface that uses a global
- scratchpad
-*/
-
-struct global_scratchpad_t : public scratchpad_t {
- global_scratchpad_t(size_t size) {
- if (size > size_) {
- if (scratchpad_ != nullptr) free(scratchpad_);
- size_ = size;
- scratchpad_ = (char *) malloc(size, page_size);
- assert(scratchpad_ != nullptr);
- }
- reference_count_++;
- }
-
- ~global_scratchpad_t() {
- reference_count_--;
- if (reference_count_ == 0) {
- free(scratchpad_);
- scratchpad_ = nullptr;
- size_ = 0;
- }
- }
-
- virtual char *get() const {
- return scratchpad_;
- }
-
-private:
- /*
- Using thread-local here is unnecessary and even buggy! All threads
- actually share the same scratchpad, which is created and queried only
- on the main thread. If the scratchpad is queried on some thread other
- than the one it was created on (e.g. the application calls the API from
- multiple threads), thread-local causes a segfault because the scratchpad
- is uninitialized on the current thread.
- */
- /*thread_local*/ static char *scratchpad_;
- /*thread_local*/ static size_t size_;
- /*thread_local*/ static unsigned int reference_count_;
-};
-
-/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr;
-/*thread_local*/ size_t global_scratchpad_t::size_ = 0;
-/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0;
-
-
-/*
- Scratchpad creation routine
-*/
-scratchpad_t *create_scratchpad(size_t size) {
-#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC
- return new global_scratchpad_t(size);
-#else
- return new concurent_scratchpad_t(size);
-#endif
-}
-
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
deleted file mode 100644
index f7a246bc99..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef COMMON_SCRATCHPAD_HPP
-#define COMMON_SCRATCHPAD_HPP
-
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct scratchpad_t {
- virtual ~scratchpad_t() {}
- virtual char *get() const = 0;
-};
-
-scratchpad_t *create_scratchpad(size_t size);
-
-}
-}
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
deleted file mode 100644
index e32e735224..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind,
- const memory_desc_t *data_desc, int axis, dim_t group_size) {
- bool args_ok = true
- && !any_null(shuffle_desc, data_desc)
- && one_of(prop_kind, forward_training, forward_inference,
- backward, backward_data)
- && axis >= 0 && axis < data_desc->ndims
- && group_size > 0 && group_size <= data_desc->dims[axis];
- if (!args_ok) return invalid_arguments;
-
- auto sd = shuffle_desc_t();
- sd.primitive_kind = primitive_kind::shuffle;
- sd.prop_kind = prop_kind;
- sd.data_desc = *data_desc;
- sd.axis = axis;
- sd.group_size = group_size;
-
- bool consistency = true
- && sd.data_desc.dims[axis] % sd.group_size == 0;
- if (!consistency) return invalid_arguments;
-
- *shuffle_desc = sd;
- return success;
-}
-}
-
-status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc,
- prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis,
- dim_t group_size) {
- if (!one_of(prop_kind, forward_training, forward_inference))
- return invalid_arguments;
- return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis,
- group_size);
-}
-
-status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc,
- const memory_desc_t *diff_data_desc, int axis, dim_t group_size) {
- return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis,
- group_size);
-}
-
-// vim: et ts=5 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
deleted file mode 100644
index cc5553fe7f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
+++ /dev/null
@@ -1,121 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SHUFFLE_PD_HPP
-#define SHUFFLE_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct shuffle_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::shuffle;
-
- typedef shuffle_pd_t base_class;
- typedef shuffle_pd_t hint_class;
-
- shuffle_pd_t(engine_t *engine,
- const shuffle_desc_t *adesc,
- const primitive_attr_t *attr,
- const shuffle_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , data_md_(desc_.data_desc)
- {}
-
- const shuffle_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::shuffle_d:
- *(const shuffle_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (is_fwd()) {
- if (arg == MKLDNN_ARG_SRC)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
- } else {
- if (arg == MKLDNN_ARG_DIFF_DST)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
- }
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
-
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override { return 1; }
-
- /* shuffle aux functions */
-
- dim_t MB() const { return data_md()->dims[0]; }
- dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; }
- dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; }
- dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; }
- dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; }
-
- int ndims() const { return data_md()->ndims; }
-
- int axis() const { return desc_.axis; }
- dim_t group_size() const { return desc_.group_size; }
- dim_t axis_size() const { return data_md()->dims[axis()]; }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
- const memory_desc_t *data_md() const { return &data_md_; }
-
-protected:
- shuffle_desc_t desc_;
- const shuffle_pd_t *hint_fwd_pd_;
- memory_desc_t data_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
deleted file mode 100644
index 82848e3d1f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
+++ /dev/null
@@ -1,68 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "memory_desc_wrapper.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::alg_kind;
-using namespace mkldnn::impl::types;
-
-namespace {
-status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind,
- const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) {
- bool args_ok = true
- && !any_null(softmax_desc, data_desc)
- && 0 <= softmax_axis
- && softmax_axis < data_desc->ndims;
- if (!args_ok) return invalid_arguments;
-
- auto sd = softmax_desc_t();
- sd.primitive_kind = primitive_kind::softmax;
- sd.prop_kind = prop_kind;
-
- bool is_bwd = (sd.prop_kind == backward_data);
- sd.data_desc = *data_desc;
- sd.diff_desc = is_bwd ? *diff_desc : zero_md();
- sd.softmax_axis = softmax_axis;
-
- *softmax_desc = sd;
- return success;
-}
-}
-
-status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc,
- prop_kind_t prop_kind, const memory_desc_t *data_desc,
- int softmax_axis) {
- if (!one_of(prop_kind, forward_inference, forward_training))
- return invalid_arguments;
- return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis);
-}
-
-status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc,
- const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc,
- int softmax_axis) {
- return softmax_desc_init(softmax_desc, prop_kind::backward_data,
- data_desc, diff_desc, softmax_axis);
-}
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
deleted file mode 100644
index 8a16ce901c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
+++ /dev/null
@@ -1,161 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SOFTMAX_PD_HPP
-#define SOFTMAX_PD_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "primitive_desc.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct softmax_fwd_pd_t;
-
-struct softmax_pd_t: public primitive_desc_t {
- static constexpr auto base_pkind = primitive_kind::softmax;
-
- softmax_pd_t(engine_t *engine,
- const softmax_desc_t *adesc,
- const primitive_attr_t *attr,
- const softmax_fwd_pd_t *hint_fwd_pd)
- : primitive_desc_t(engine, attr, base_pkind)
- , desc_(*adesc)
- , hint_fwd_pd_(hint_fwd_pd)
- , data_md_(desc_.data_desc)
- {}
-
- const softmax_desc_t *desc() const { return &desc_; }
- virtual const op_desc_t *op_desc() const override
- { return reinterpret_cast<const op_desc_t *>(this->desc()); }
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual status_t query(query_t what, int idx, void *result) const override {
- switch (what) {
- case query::softmax_d:
- *(const softmax_desc_t**)result = desc(); break;
- default: return primitive_desc_t::query(what, idx, result);
- }
- return status::success;
- }
-
- /* common softmax aux functions */
-
- dim_t MB() const { return data_desc().dims[0]; }
- dim_t C() const { return data_desc().dims[1]; }
- dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
- dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
- dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
-
- int ndims() const { return data_desc().ndims; }
-
- bool is_fwd() const {
- return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- }
-
-protected:
- softmax_desc_t desc_;
- const softmax_fwd_pd_t *hint_fwd_pd_;
-
- memory_desc_t data_md_;
-
-private:
- const memory_desc_t &data_desc() const { return desc_.data_desc; }
-};
-
-struct softmax_fwd_pd_t: public softmax_pd_t {
- typedef softmax_fwd_pd_t base_class;
- typedef softmax_fwd_pd_t hint_class;
-
- softmax_fwd_pd_t(engine_t *engine,
- const softmax_desc_t *adesc,
- const primitive_attr_t *attr,
- const softmax_fwd_pd_t *hint_fwd_pd)
- : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg == MKLDNN_ARG_SRC)
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
-
- virtual int n_inputs() const override { return 1; }
- virtual int n_outputs() const override
- { return 1 + (workspace_md() != nullptr); }
-};
-
-struct softmax_bwd_pd_t: public softmax_pd_t {
- typedef softmax_bwd_pd_t base_class;
- typedef softmax_fwd_pd_t hint_class;
-
- softmax_bwd_pd_t(engine_t *engine,
- const softmax_desc_t *adesc,
- const primitive_attr_t *attr,
- const softmax_fwd_pd_t *hint_fwd_pd)
- : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
- , diff_data_md_(desc_.diff_desc)
- {}
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST))
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DIFF_SRC)
- return arg_usage_t::output;
-
- if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
- return arg_usage_t::input;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &data_md_ : nullptr; }
- virtual const memory_desc_t *diff_dst_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
- virtual const memory_desc_t *diff_src_md(int index = 0) const override
- { return index == 0 ? &diff_data_md_ : nullptr; }
-
- virtual int n_inputs() const override
- { return 2 + (workspace_md() != nullptr); }
- virtual int n_outputs() const override { return 1; }
-
-protected:
- memory_desc_t diff_data_md_;
-};
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
deleted file mode 100644
index 00af8935c0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
+++ /dev/null
@@ -1,46 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "stream.hpp"
-#include "utils.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::status;
-
-/* API */
-
-status_t mkldnn_stream_create(stream_t **stream, engine_t *engine,
- unsigned flags) {
- bool args_ok = true
- && !utils::any_null(stream, engine)
- && flags == stream_flags::default_flags;
- if (!args_ok)
- return invalid_arguments;
-
- return safe_ptr_assign<stream_t>(*stream, new stream_t(engine, flags));
-}
-
-status_t mkldnn_stream_destroy(stream_t *stream) {
- delete stream;
- return success;
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
deleted file mode 100644
index f010e5f6ed..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
+++ /dev/null
@@ -1,44 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef STREAM_HPP
-#define STREAM_HPP
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-
-struct mkldnn_stream: public mkldnn::impl::c_compatible {
- mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags)
- : engine_(engine), flags_(flags) {}
- virtual ~mkldnn_stream() {}
-
- /** returns stream's engine */
- mkldnn::impl::engine_t *engine() const { return engine_; }
-
- /** returns stream's kind */
- unsigned flags() const { return flags_; }
-
-protected:
- mkldnn::impl::engine_t *engine_;
- unsigned flags_;
-};
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
deleted file mode 100644
index 365663c0f8..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
+++ /dev/null
@@ -1,79 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "engine.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "sum_pd.hpp"
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::status;
-
-status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd,
- const memory_desc_t *dst_md, int n, const float *scales,
- const memory_desc_t *src_mds, const primitive_attr_t *attr,
- engine_t *engine) {
- bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0;
- if (!args_ok) return invalid_arguments;
-
- const primitive_attr_t dummy_attr;
- if (attr == NULL)
- attr = &dummy_attr;
-
- const int ndims = src_mds[0].ndims;
- const dims_t &dims = src_mds[0].dims;
- const data_type_t dt = src_mds[0].data_type;
-
- for (int i = 1; i < n; ++i) {
- if (src_mds[i].ndims != ndims) return invalid_arguments;
- for (int d = 0; d < ndims; ++d) {
- if (src_mds[i].dims[d] != dims[d])
- return invalid_arguments;
- }
- if (src_mds[i].data_type != dt) return invalid_arguments;
- }
-
- memory_desc_t dummy_dst_md;
- if (dst_md) {
- if (dst_md->ndims != ndims) return invalid_arguments;
- for (int d = 0; d < ndims; ++d) {
- if (dst_md->dims[d] != dims[d])
- return invalid_arguments;
- }
- } else {
- dummy_dst_md = src_mds[0];
- dummy_dst_md.format_kind = format_kind::any;
- dst_md = &dummy_dst_md;
- }
-
- auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd);
-
- for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
- if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) {
- (*s_pd)->init_info();
- (*s_pd)->init_scratchpad_md();
- return success;
- }
- }
- return unimplemented;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
deleted file mode 100644
index 80254667df..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
+++ /dev/null
@@ -1,143 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SUM_PD_HPP
-#define SUM_PD_HPP
-
-#include <assert.h>
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "primitive_desc.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct sum_pd_t: public primitive_desc_t {
- sum_pd_t(engine_t *engine, const primitive_attr_t *attr,
- const memory_desc_t *dst_md, int n, const float *scales,
- const memory_desc_t *src_mds)
- : primitive_desc_t(engine, attr, primitive_kind::sum)
- , n_(n), dst_md_(*dst_md)
- {
- scales_.reserve(n_);
- for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]);
- src_mds_.reserve(n_);
- for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
- }
-
- virtual void init_info() override { impl::init_info(this, this->info_); }
-
- virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
- if (arg >= MKLDNN_ARG_MULTIPLE_SRC
- && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
- return arg_usage_t::input;
-
- if (arg == MKLDNN_ARG_DST)
- return arg_usage_t::output;
-
- return primitive_desc_t::arg_usage(arg);
- }
-
- virtual const memory_desc_t *src_md(int index = 0) const override
- { return index < n_inputs() ? &src_mds_[index] : nullptr; }
- virtual const memory_desc_t *dst_md(int index = 0) const override
- { return index == 0 ? &dst_md_ : nullptr; }
-
- virtual int n_inputs() const override { return n_; }
- virtual int n_outputs() const override { return 1; }
-
- const float *scales() const { return &scales_[0]; }
-
-protected:
- int n_;
- nstl::vector<float> scales_;
- memory_desc_t dst_md_;
- nstl::vector<memory_desc_t> src_mds_;
-
-protected:
- /* inits dst_md_ in simple cases. The call may fail. */
- status_t init() {
- for (int i = 0; i < n_; ++i) {
- const memory_desc_wrapper src_d(&src_mds_[i]);
- if (!src_d.is_blocking_desc() || src_d.is_additional_buffer())
- return status::unimplemented;
- }
- bool ok = true
- && set_default_params() == status::success
- && attr()->has_default_values();
- return ok ? status::success : status::unimplemented;
- }
-
- status_t set_default_params() {
- if (dst_md_.format_kind != format_kind::any)
- return status::success;
-
- /* The stupidest ever heuristics (but not the same as we had before):
- * - Pick the first non-plain format;
- * - If all formats are plain, pick the format of the first input
- */
- for (int i = 0; i < n_; ++i) {
- const memory_desc_wrapper src_d(src_mds_[i]);
- if (!src_d.is_plain() && src_d.is_blocking_desc()) {
- return memory_desc_init_by_blocking_desc(dst_md_,
- src_d.blocking_desc());
- }
- }
-
- if (src_mds_[0].format_kind != format_kind::blocked)
- return status::unimplemented;
-
- dst_md_ = src_mds_[0];
-
- return status::success;
- }
-};
-
-#define DECLARE_SUM_PD_t(impl_name, ...) \
- static status_t create(sum_pd_t **sum_pd, \
- engine_t *engine, const primitive_attr_t *attr, \
- const memory_desc_t *dst_md, int n, const float *scales, \
- const memory_desc_t *src_mds) { \
- using namespace status; \
- auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \
- if (_pd == nullptr) return out_of_memory; \
- if (_pd->init() != success) { delete _pd; return unimplemented; } \
- return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd); \
- } \
- virtual status_t create_primitive(primitive_t **p) const override { \
- double ms = get_msec(); \
- auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
- ms = get_msec() - ms; \
- if (mkldnn_verbose()->level >= 2) { \
- printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
- fflush(0); \
- } \
- return ret; \
- } \
- virtual pd_t *clone() const override { return new pd_t(*this); } \
- virtual const char *name() const override { return impl_name; } \
-
-#define DECLARE_SUM_PD_T(impl_name, ...) \
- DECLARE_SUM_PD_t(impl_name, __VA_ARGS__)
-
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
deleted file mode 100644
index a408f45980..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
+++ /dev/null
@@ -1,200 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef TAG_TRAITS_HPP
-#define TAG_TRAITS_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-enum class block_dim_t {
- _,
- _A, _B,
- _AB, _BC,
-};
-
-enum class inner_blk_t {
- _,
- _4a, _4b,
- _8a, _8b,
- _16a, _16b,
-
- _4b4a, _4b4c, _4c4b,
- _8a8b, _8b8a, _8b8c, _8c8b,
- _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b,
-
- _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c,
-};
-
-/** returns the offset within the block for weights blocked over oc and ic */
-template <inner_blk_t f>
-constexpr int AB_or_BC_blk_off(int x0, int x1) {
- using ib = inner_blk_t;
- static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b,
- ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b,
- ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c,
- ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
- ib::_4c16b4c, ib::_8c16b2c),
- "unexpected inner_blk format");
- return false ? 0
- : (f == ib::_4b4c) ? 4 * x0 + x1
- : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0
- : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1
- : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0
- : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1
- : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0
- : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1
- : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2
- : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
- : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
- : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4
- : INT_MIN;
-}
-
-template <inner_blk_t b> struct inner_blk_traits {
- using ib = inner_blk_t;
-};
-
-template <format_tag_t> struct tag_traits {
- // block_dim_t block_dims;
- // inner_blk_t inner_blks;
- // int ndims;
-};
-
-#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \
-template <> struct tag_traits<format_tag::_tag> { \
- static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \
- static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \
- static constexpr int ndims = _ndims; \
-}
-
-DECL_TRAITS(a, _, _, 1);
-DECL_TRAITS(ab, _, _, 2);
-DECL_TRAITS(abc, _, _, 3);
-DECL_TRAITS(abcd, _, _, 4);
-DECL_TRAITS(abcde, _, _, 5);
-DECL_TRAITS(abcdef, _, _, 6);
-DECL_TRAITS(abdec, _, _, 5);
-DECL_TRAITS(acb, _, _, 3);
-DECL_TRAITS(acbde, _, _, 5);
-DECL_TRAITS(acdb, _, _, 4);
-DECL_TRAITS(acdeb, _, _, 5);
-DECL_TRAITS(ba, _, _, 2);
-DECL_TRAITS(bac, _, _, 3);
-DECL_TRAITS(bacd, _, _, 4);
-DECL_TRAITS(bcda, _, _, 4);
-DECL_TRAITS(cba, _, _, 3);
-DECL_TRAITS(cdba, _, _, 4);
-DECL_TRAITS(cdeba, _, _, 5);
-DECL_TRAITS(decab, _, _, 5);
-
-DECL_TRAITS(Abc4a, _A, _4a, 3);
-DECL_TRAITS(aBc4b, _B, _4b, 3);
-DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3);
-DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3);
-DECL_TRAITS(Abcd4a, _A, _4a, 4);
-DECL_TRAITS(aBcd4b, _B, _4b, 4);
-DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4);
-DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4);
-DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4);
-DECL_TRAITS(Abcde4a, _A, _4a, 5);
-DECL_TRAITS(aBcde4b, _B, _4b, 5);
-DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5);
-DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5);
-DECL_TRAITS(aBcdef4b, _B, _4b, 6);
-DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6);
-DECL_TRAITS(aBdc4b, _B, _4b, 4);
-DECL_TRAITS(aBdec4b, _B, _4b, 5);
-DECL_TRAITS(aBdefc4b, _B, _4b, 6);
-DECL_TRAITS(Acb4a, _A, _4a, 3);
-DECL_TRAITS(Acdb4a, _A, _4a, 4);
-DECL_TRAITS(Acdeb4a, _A, _4a, 5);
-
-DECL_TRAITS(Abc16a, _A, _16a, 3);
-DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3);
-DECL_TRAITS(aBc16b, _B, _16b, 3);
-DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3);
-DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3);
-DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3);
-DECL_TRAITS(aBc8b, _B, _8b, 3);
-DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3);
-DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3);
-DECL_TRAITS(Abcd16a, _A, _16a, 4);
-DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4);
-DECL_TRAITS(aBcd16b, _B, _16b, 4);
-DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4);
-DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4);
-DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4);
-DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4);
-DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4);
-DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4);
-DECL_TRAITS(aBcd8b, _B, _8b, 4);
-DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4);
-DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4);
-DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4);
-DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4);
-DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4);
-DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4);
-DECL_TRAITS(Abcde16a, _A, _16a, 5);
-DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5);
-DECL_TRAITS(aBcde16b, _B, _16b, 5);
-DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5);
-DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5);
-DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5);
-DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5);
-DECL_TRAITS(Abcde8a, _A, _8a, 5);
-DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5);
-DECL_TRAITS(aBcde8b, _B, _8b, 5);
-DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5);
-DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5);
-DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5);
-DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5);
-DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5);
-DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5);
-DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5);
-DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5);
-DECL_TRAITS(aBcdef16b, _B, _16b, 6);
-DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6);
-DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6);
-DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6);
-DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6);
-DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6);
-DECL_TRAITS(aBdc16b, _B, _16b, 4);
-DECL_TRAITS(aBdc8b, _B, _8b, 4);
-DECL_TRAITS(aBdec16b, _B, _16b, 5);
-DECL_TRAITS(aBdec8b, _B, _8b, 5);
-DECL_TRAITS(aBdefc16b, _B, _16b, 6);
-DECL_TRAITS(aBdefc8b, _B, _8b, 6);
-DECL_TRAITS(Acb16a, _A, _16a, 3);
-DECL_TRAITS(Acb8a, _A, _8a, 3);
-DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4);
-DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5);
-DECL_TRAITS(Acdb16a, _A, _16a, 4);
-DECL_TRAITS(Acdb8a, _A, _8a, 4);
-DECL_TRAITS(Acdeb16a, _A, _16a, 5);
-DECL_TRAITS(Acdeb8a, _A, _8a, 5);
-DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3);
-DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4);
-
-} // namespace impl
-} // namespace mkldnn
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
deleted file mode 100644
index 4f06368738..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
+++ /dev/null
@@ -1,348 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef TYPE_HELPERS_HPP
-#define TYPE_HELPERS_HPP
-
-#include <assert.h>
-#include <math.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "mkldnn_traits.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-#include "math_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-template <typename T>
-status_t safe_ptr_assign(T * &lhs, T* rhs) {
- if (rhs == nullptr) return status::out_of_memory;
- lhs = rhs;
- return status::success;
-}
-
-template <typename T, typename U> struct is_subset
-{ static constexpr bool value = false; };
-template <typename T> struct is_subset<T, T>
-{ static constexpr bool value = true; };
-template <typename T> struct is_subset<T,
- typename utils::enable_if<nstl::is_integral<T>::value, float>::type>
-{ static constexpr bool value = true; };
-#define ISSPEC(t1, t2) template <> \
- struct is_subset<t1, t2> { static constexpr bool value = true; }
-ISSPEC(int16_t, int32_t);
-ISSPEC(int8_t, int32_t);
-ISSPEC(uint8_t, int32_t);
-ISSPEC(int8_t, int16_t);
-ISSPEC(uint8_t, int16_t);
-#undef ISSPEC
-
-inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs);
-
-namespace types {
-
-inline size_t data_type_size(data_type_t data_type) {
- using namespace data_type;
- switch (data_type) {
- case f32: return sizeof(prec_traits<f32>::type);
- case s32: return sizeof(prec_traits<s32>::type);
- case s8: return sizeof(prec_traits<s8>::type);
- case u8: return sizeof(prec_traits<u8>::type);
- case data_type::undef:
- default: assert(!"unknown data_type");
- }
- return 0; /* not supposed to be reachable */
-}
-
-inline format_kind_t format_tag_to_kind(format_tag_t tag) {
- switch (tag) {
- case format_tag::undef: return format_kind::undef;
- case format_tag::any: return format_kind::any;
- case format_tag::last: return format_kind::undef;
- default: return format_kind::blocked;
- }
-
- assert(!"unreachable");
- return format_kind::undef;
-}
-
-inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs,
- const memory_extra_desc_t &rhs) {
- return true
- && lhs.flags == rhs.flags
- && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8,
- lhs.compensation_mask == rhs.compensation_mask)
- && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust,
- lhs.scale_adjust == rhs.scale_adjust);
-}
-
-inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
- const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) {
- using mkldnn::impl::utils::array_cmp;
- return true
- && lhs.inner_nblks == rhs.inner_nblks
- && array_cmp(lhs.strides, rhs.strides, ndims)
- && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks)
- && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
-}
-
-inline bool wino_desc_is_equal(const wino_desc_t &lhs,
- const wino_desc_t &rhs) {
- return lhs.wino_format == rhs.wino_format
- && lhs.alpha == rhs.alpha
- && lhs.ic == rhs.ic
- && lhs.oc == rhs.oc
- && lhs.ic_block == rhs.ic_block
- && lhs.oc_block == rhs.oc_block
- && lhs.ic2_block == rhs.ic2_block
- && lhs.oc2_block == rhs.oc2_block
- && lhs.r == rhs.r;
-}
-
-inline bool rnn_packed_desc_is_equal(
- const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
- bool ok = true
- && lhs.format == rhs.format
- && lhs.n_parts == rhs.n_parts
- && lhs.offset_compensation == rhs.offset_compensation
- && lhs.size == rhs.size
- && lhs.n == rhs.n;
- if (!ok)
- return false;
-
- for (int i = 0; i < rhs.n_parts; i++)
- ok = ok && lhs.parts[i] == rhs.parts[i];
- for (int i = 0; i < rhs.n_parts; i++)
- ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
- return ok;
-}
-
-inline memory_desc_t zero_md() {
- auto zero = memory_desc_t();
- return zero;
-}
-
-inline bool is_zero_md(const memory_desc_t *md) {
- return md == nullptr || *md == zero_md();
-}
-
-inline data_type_t default_accum_data_type(data_type_t src_dt,
- data_type_t dst_dt) {
- using namespace utils;
- using namespace data_type;
-
- if (one_of(f32, src_dt, dst_dt)) return f32;
- if (one_of(s32, src_dt, dst_dt)) return s32;
-
- if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
-
- assert(!"unimplemented use-case: no default parameters available");
- return dst_dt;
-}
-
-inline data_type_t default_accum_data_type(data_type_t src_dt,
- data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
- using namespace utils;
- using namespace data_type;
- using namespace prop_kind;
-
- /* prop_kind doesn't matter */
- if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32;
-
- if (one_of(prop_kind, forward_training, forward_inference)) {
- if ((src_dt == u8 || src_dt == s8)
- && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
- return s32;
- } else if (prop_kind == backward_data) {
- if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
- one_of(dst_dt, s8, u8))
- return s32;
- }
-
- assert(!"unimplemented use-case: no default parameters available");
- return dst_dt;
-}
-
-}
-
-inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
- using namespace mkldnn::impl::utils;
- bool base_equal = true
- && lhs.ndims == rhs.ndims
- && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
- && lhs.data_type == rhs.data_type
- && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
- && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
- && lhs.offset0 == rhs.offset0
- && lhs.format_kind == rhs.format_kind;
- if (!base_equal) return false;
- if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false;
- if (lhs.format_kind == format_kind::blocked)
- return types::blocking_desc_is_equal(lhs.format_desc.blocking,
- rhs.format_desc.blocking, lhs.ndims);
- else if (lhs.format_kind == format_kind::wino)
- return types::wino_desc_is_equal(lhs.format_desc.wino_desc,
- rhs.format_desc.wino_desc);
- else if (lhs.format_kind == format_kind::rnn_packed)
- return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc,
- rhs.format_desc.rnn_packed_desc);
- return true;
-}
-
-inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
- return !operator==(lhs, rhs);
-}
-
-inline status_t memory_desc_init_by_strides(memory_desc_t &md,
- const dims_t strides) {
- return mkldnn_memory_desc_init_by_strides(
- &md, md.ndims, md.dims, md.data_type, strides);
-}
-
-inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag,
- const dims_t strides = nullptr) {
- status_t status = mkldnn_memory_desc_init_by_tag(
- &md, md.ndims, md.dims, md.data_type, tag);
- if (status != status::success || strides == nullptr)
- return status;
-
- /* TODO: add consistency check */
-
- for (int d = 0; d < md.ndims; ++d)
- md.format_desc.blocking.strides[d] = strides[d];
-
- return status::success;
-}
-
-/** inits memory descriptor based on logical dimensions kept in @p md, and the
- * blocking structure @p blk.
- *
- * @note blk.strides only represent the order (from smallest to largest)
- *
- * TODO: move md-related functions to a single place
- */
-inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md,
- const blocking_desc_t &blk) {
- dims_t blocks = {0};
- utils::array_set(blocks, 1, md.ndims);
- dim_t block_size = 1;
- for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
- blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
- block_size *= blk.inner_blks[iblk];
- }
-
- for (int d = 0; d < md.ndims; ++d) {
- md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
- md.padded_offsets[d] = 0;
- }
- md.offset0 = 0;
-
- md.format_kind = format_kind::blocked;
- auto &mblk = md.format_desc.blocking;
- mblk = blk;
-
- const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy
- utils::array_copy(mblk.strides, blk.strides, ndims);
-
- int perm[MKLDNN_MAX_NDIMS];
- for (int d = 0; d < ndims; ++d) perm[d] = d;
-
- utils::simultaneous_sort(mblk.strides, perm, ndims,
- [](stride_t a, stride_t b) { return b - a; });
-
- dim_t stride = block_size;
- for (int _d = ndims - 1; _d >= 0; --_d) {
- const int d = perm[_d];
- md.format_desc.blocking.strides[d] = stride;
- stride *= md.padded_dims[d] / blocks[d];
- }
-
- md.extra = utils::zero<memory_extra_desc_t>();
-
- return status::success;
-}
-
-/** returns true if memory desc @p md corresponds to the given format tag and
- * strides.
- * If strides are not passed (or passed as nullptr) the dense structure is
- * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns).
- * Strides might contain a `0` value, indicating that the stride must match
- * the one that mkldnn_memory_desc_init_by_tag() returns.
- * Strides might contain `-1` values, which are ignored during the
- * comparison. For instance, this can be used if the stride along the
- * minibatch dimension doesn't matter. */
-inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
- const dims_t strides = nullptr) {
- if (md.format_kind != types::format_tag_to_kind(tag))
- return false;
-
- memory_desc_t md_gold;
- status_t status = mkldnn_memory_desc_init_by_tag(
- &md_gold, md.ndims, md.dims, md.data_type, tag);
- if (status != status::success) return false;
-
- if (md.format_kind != format_kind::blocked)
- return false; // unimplemented yet
-
- const auto &blk = md.format_desc.blocking;
- const auto &blk_gold = md_gold.format_desc.blocking;
-
- using utils::array_cmp;
- bool same_blocks = true
- && blk.inner_nblks == blk_gold.inner_nblks
- && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
- && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);
-
- if (!same_blocks)
- return false;
-
- if (strides == nullptr)
- return array_cmp(blk.strides, blk_gold.strides, md.ndims);
-
- for (int d = 0; d < md.ndims; ++d) {
- dim_t stride = strides[d];
- if (stride == -1) continue;
- if (stride == 0) stride = blk_gold.strides[d];
- if (blk.strides[d] != stride) return false;
- }
-
- return true;
-}
-
-/** returns matching tag (or undef if match is not found)
- * XXX: This is a workaround that eventually should go away! */
-template <typename... Tags>
-format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md,
- Tags ...tags) {
- for (const auto tag: {tags...}) {
- if (memory_desc_matches_tag(md, tag))
- return tag;
- }
- return format_tag::undef;
-}
-
-}
-}
-
-#include "memory_desc_wrapper.hpp"
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
deleted file mode 100644
index d23f4682dc..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
+++ /dev/null
@@ -1,135 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <string.h>
-#ifdef _WIN32
-#include <malloc.h>
-#include <windows.h>
-#endif
-#include <limits.h>
-#include <stdlib.h>
-#include <stdio.h>
-
-#include "mkldnn.h"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-int getenv(const char *name, char *buffer, int buffer_size) {
- if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0))
- return INT_MIN;
-
- int result = 0;
- int term_zero_idx = 0;
- size_t value_length = 0;
-
-#ifdef _WIN32
- value_length = GetEnvironmentVariable(name, buffer, buffer_size);
-#else
- const char *value = ::getenv(name);
- value_length = value == NULL ? 0 : strlen(value);
-#endif
-
- if (value_length > INT_MAX)
- result = INT_MIN;
- else {
- int int_value_length = (int)value_length;
- if (int_value_length >= buffer_size) {
- result = -int_value_length;
- } else {
- term_zero_idx = int_value_length;
- result = int_value_length;
-#ifndef _WIN32
- strncpy(buffer, value, value_length);
-#endif
- }
- }
-
- if (buffer != NULL)
- buffer[term_zero_idx] = '\0';
- return result;
-}
-
-int getenv_int(const char *name, int default_value)
-{
- int value = default_value;
- // # of digits in the longest 32-bit signed int + sign + terminating null
- const int len = 12;
- char value_str[len];
- if (getenv(name, value_str, len) > 0)
- value = atoi(value_str);
- return value;
-}
-
-FILE *fopen(const char *filename, const char *mode) {
-#ifdef _WIN32
- FILE *fp = NULL;
- return ::fopen_s(&fp, filename, mode) ? NULL : fp;
-#else
- return ::fopen(filename, mode);
-#endif
-}
-
-void *malloc(size_t size, int alignment) {
- void *ptr;
-
-#ifdef _WIN32
- ptr = _aligned_malloc(size, alignment);
- int rc = ptr ? 0 : -1;
-#else
- int rc = ::posix_memalign(&ptr, alignment, size);
-#endif
-
- return (rc == 0) ? ptr : 0;
-}
-
-void free(void *p) {
-#ifdef _WIN32
- _aligned_free(p);
-#else
- ::free(p);
-#endif
-}
-
-// Atomic operations
-int32_t fetch_and_add(int32_t *dst, int32_t val) {
-#ifdef _WIN32
- return InterlockedExchangeAdd(reinterpret_cast<long*>(dst), val);
-#else
- return __sync_fetch_and_add(dst, val);
-#endif
-}
-
-static int jit_dump_flag = 0;
-static bool jit_dump_flag_initialized = false;
-bool jit_dump_enabled() {
- if (!jit_dump_flag_initialized) {
- jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP");
- jit_dump_flag_initialized = true;
- }
- return jit_dump_flag != 0;
-}
-
-}
-}
-
-mkldnn_status_t mkldnn_set_jit_dump(int enabled) {
- using namespace mkldnn::impl::status;
- mkldnn::impl::jit_dump_flag = enabled;
- mkldnn::impl::jit_dump_flag_initialized = true;
- return success;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
deleted file mode 100644
index d5a8ec5139..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
+++ /dev/null
@@ -1,370 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef UTILS_HPP
-#define UTILS_HPP
-
-#include <stddef.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <assert.h>
-#include <stdint.h>
-
-#if defined(__x86_64__) || defined(_M_X64)
-#define MKLDNN_X86_64
-#endif
-
-#define MSAN_ENABLED 0
-#if defined(__has_feature)
-#if __has_feature(memory_sanitizer)
-#undef MSAN_ENABLED
-#define MSAN_ENABLED 1
-#include <sanitizer/msan_interface.h>
-#endif
-#endif
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "z_magic.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-// Sanity check for 64 bits
-static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only");
-
-#define CHECK(f) do { \
- status_t status = f; \
- if (status != status::success) \
- return status; \
-} while (0)
-
-#define IMPLICATION(cause, effect) (!(cause) || !!(effect))
-
-namespace utils {
-
-/* a bunch of std:: analogues to be compliant with any msvs version
- *
- * Rationale: msvs c++ (and even some c) headers contain a special pragma that
- * injects an msvs-version check into object files in order to prevent
- * abi-mismatches during static linking. This makes sense if e.g. std::
- * objects are passed between the application and the library, which is not
- * the case for mkl-dnn (since ideally there is no c++-rt dependent stuff...). */
-
-/* SFINAE helper -- analogue to std::enable_if */
-template<bool expr, class T = void> struct enable_if {};
-template<class T> struct enable_if<true, T> { typedef T type; };
-
-/* analogue std::conditional */
-template <bool, typename, typename> struct conditional {};
-template <typename T, typename F> struct conditional<true, T, F>
-{ typedef T type; };
-template <typename T, typename F> struct conditional<false, T, F>
-{ typedef F type; };
-
-template <bool, typename, bool, typename, typename> struct conditional3 {};
-template <typename T, typename FT, typename FF>
-struct conditional3<true, T, false, FT, FF> { typedef T type; };
-template <typename T, typename FT, typename FF>
-struct conditional3<false, T, true, FT, FF> { typedef FT type; };
-template <typename T, typename FT, typename FF>
-struct conditional3<false, T, false, FT, FF> { typedef FF type; };
-
-template <bool, typename U, U, U> struct conditional_v {};
-template <typename U, U t, U f> struct conditional_v<true, U, t, f>
-{ static constexpr U value = t; };
-template <typename U, U t, U f> struct conditional_v<false, U, t, f>
-{ static constexpr U value = f; };
-
-template <typename T> struct remove_reference { typedef T type; };
-template <typename T> struct remove_reference<T&> { typedef T type; };
-template <typename T> struct remove_reference<T&&> { typedef T type; };
-
-template <typename T>
-inline T&& forward(typename utils::remove_reference<T>::type &t)
-{ return static_cast<T&&>(t); }
-template <typename T>
-inline T&& forward(typename utils::remove_reference<T>::type &&t)
-{ return static_cast<T&&>(t); }
-
-template <typename T>
-inline typename remove_reference<T>::type zero()
-{ auto zero = typename remove_reference<T>::type(); return zero; }
-
-template <typename T, typename P>
-inline bool everyone_is(T val, P item) { return val == item; }
-template <typename T, typename P, typename... Args>
-inline bool everyone_is(T val, P item, Args... item_others) {
- return val == item && everyone_is(val, item_others...);
-}
-
-template <typename T, typename P>
-constexpr bool one_of(T val, P item) { return val == item; }
-template <typename T, typename P, typename... Args>
-constexpr bool one_of(T val, P item, Args... item_others) {
- return val == item || one_of(val, item_others...);
-}
-
-template <typename... Args>
-inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); }
-
-template<typename T>
-inline void array_copy(T *dst, const T *src, size_t size) {
- for (size_t i = 0; i < size; ++i) dst[i] = src[i];
-}
-template<typename T>
-inline bool array_cmp(const T *a1, const T *a2, size_t size) {
- for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false;
- return true;
-}
-template<typename T, typename U>
-inline void array_set(T *arr, const U& val, size_t size) {
- for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
-}
-
-namespace product_impl {
-template<size_t> struct int2type{};
-
-template <typename T>
-constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; }
-
-template <typename T, size_t num>
-inline T product_impl(const T *arr, int2type<num>) {
- return arr[0]*product_impl(arr+1, int2type<num-1>()); }
-}
-
-template <size_t num, typename T>
-inline T array_product(const T *arr) {
- return product_impl::product_impl(arr, product_impl::int2type<num-1>());
-}
-
-template<typename T, typename R = T>
-inline R array_product(const T *arr, size_t size) {
- R prod = 1;
- for (size_t i = 0; i < size; ++i) prod *= arr[i];
- return prod;
-}
-
-/** sorts an array of values using @p comparator. While sorting the array
- * of values, the function permutes the array of @p keys accordingly.
- *
- * @note The array of @p keys can be omitted. In this case the function
- * sorts the array of @p vals only.
- */
-template <typename T, typename U, typename F>
-inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) {
- if (size == 0) return;
-
- for (size_t i = 0; i < size - 1; ++i) {
- bool swapped = false;
-
- for (size_t j = 0; j < size - i - 1; j++) {
- if (comparator(vals[j], vals[j + 1]) > 0) {
- nstl::swap(vals[j], vals[j + 1]);
- if (keys) nstl::swap(keys[j], keys[j + 1]);
- swapped = true;
- }
- }
-
- if (swapped == false) break;
- }
-}
-
-template <typename T, typename U>
-inline typename remove_reference<T>::type div_up(const T a, const U b) {
- assert(b);
- return (a + b - 1) / b;
-}
-
-template <typename T, typename U>
-inline typename remove_reference<T>::type rnd_up(const T a, const U b) {
- return div_up(a, b) * b;
-}
-
-template <typename T, typename U>
-inline typename remove_reference<T>::type rnd_dn(const T a, const U b) {
- return (a / b) * b;
-}
-
-template <typename T> T *align_ptr(T *ptr, uintptr_t alignment)
-{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); }
-
-template <typename T, typename U, typename V>
-inline U this_block_size(const T offset, const U max, const V block_size) {
- assert(offset < max);
- // TODO (Roma): can't use nstl::max() due to circular dependency... we
- // need to fix this
- const T block_boundary = offset + block_size;
- if (block_boundary > max)
- return max - offset;
- else
- return block_size;
-}
-
-template<typename T>
-inline T nd_iterator_init(T start) { return start; }
-template<typename T, typename U, typename W, typename... Args>
-inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) {
- start = nd_iterator_init(start, utils::forward<Args>(tuple)...);
- x = start % X;
- return start / X;
-}
-
-inline bool nd_iterator_step() { return true; }
-template<typename U, typename W, typename... Args>
-inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) {
- if (nd_iterator_step(utils::forward<Args>(tuple)...) ) {
- x = (x + 1) % X;
- return x == 0;
- }
- return false;
-}
-
-template<typename U, typename W, typename Y>
-inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X)
-{
- U max_jump = end - cur;
- U dim_jump = X - x;
- if (dim_jump <= max_jump) {
- x = 0;
- cur += dim_jump;
- return true;
- } else {
- cur += max_jump;
- x += max_jump;
- return false;
- }
-}
-template<typename U, typename W, typename Y, typename... Args>
-inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X,
- Args &&... tuple)
-{
- if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) {
- x = (x + 1) % X;
- return x == 0;
- }
- return false;
-}
-
-template <typename T>
-inline T pick(size_t i, const T &x0) { return x0; }
-template <typename T, typename ...Args>
-inline T pick(size_t i, const T &x0, Args &&... args) {
- return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...);
-}
-
-template <typename T>
-T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference,
- const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) {
- switch (prop_kind) {
- case prop_kind::forward_inference: return val_fwd_inference;
- case prop_kind::forward_training: return val_fwd_training;
- case prop_kind::backward_data: return val_bwd_d;
- case prop_kind::backward_weights: return val_bwd_w;
- default: assert(!"unsupported prop_kind");
- }
- return T();
-}
-
-template <typename T>
-T pick_by_prop_kind(prop_kind_t prop_kind,
- const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w)
-{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); }
-
-template <typename Telem, size_t Tdims>
-struct array_offset_calculator {
- template <typename... Targs>
- array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... }
- {
- _base_ptr = base;
- }
- template <typename... Targs>
- inline Telem &operator()(Targs... Fargs)
- {
- return *(_base_ptr + _offset(1, Fargs...));
- }
-
-private:
- template <typename... Targs>
- inline size_t _offset(size_t const dimension, size_t element)
- {
- return element;
- }
-
- template <typename... Targs>
- inline size_t _offset(size_t const dimension, size_t theta, size_t element)
- {
- return element + (_dims[dimension] * theta);
- }
-
- template <typename... Targs>
- inline size_t _offset(size_t const dimension, size_t theta, size_t element,
- Targs... Fargs)
- {
- size_t t_prime = element + (_dims[dimension] * theta);
- return _offset(dimension + 1, t_prime, Fargs...);
- }
-
- Telem *_base_ptr;
- const int _dims[Tdims];
-};
-
-}
-
-int32_t fetch_and_add(int32_t *dst, int32_t val);
-inline void yield_thread() {}
-
-// Reads an environment variable 'name' and stores its string value in the
-// 'buffer' of 'buffer_size' bytes on success.
-//
-// - Returns the length of the environment variable string value (excluding
-// the terminating 0) if it is set and its contents (including the terminating
-// 0) can be stored in the 'buffer' without truncation.
-//
-// - Returns the negated length of the environment variable string value and
-// writes "\0" to the buffer (if it is not NULL) if the 'buffer_size' is too
-// small to store the value (including the terminating 0) without truncation.
-//
-// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment
-// variable is not set.
-//
-// - Returns INT_MIN if the 'name' is NULL.
-//
-// - Returns INT_MIN if the 'buffer_size' is negative.
-//
-// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than
-// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to
-// retrieve the length of the environment variable value string.
-//
-int getenv(const char *name, char *buffer, int buffer_size);
-// Reads an integer from the environment
-int getenv_int(const char *name, int default_value = 0);
-bool jit_dump_enabled();
-FILE *fopen(const char *filename, const char *mode);
-
-constexpr int msan_enabled = MSAN_ENABLED;
-inline void msan_unpoison(void *ptr, size_t size) {
-#if MSAN_ENABLED
- __msan_unpoison(ptr, size);
-#endif
-}
-
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
deleted file mode 100644
index 89a57772cf..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
+++ /dev/null
@@ -1,665 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <stdlib.h>
-#ifndef _WIN32
-#include <sys/time.h>
-#endif
-
-#include "mkldnn.h"
-#include "mkldnn_version.h"
-#include "c_types_map.hpp"
-#include "verbose.hpp"
-#include "cpu/cpu_isa_traits.hpp"
-
-#include "batch_normalization_pd.hpp"
-#include "pooling_pd.hpp"
-#include "concat_pd.hpp"
-#include "reorder_pd.hpp"
-#include "convolution_pd.hpp"
-#include "rnn_pd.hpp"
-#include "deconvolution_pd.hpp"
-#include "shuffle_pd.hpp"
-#include "eltwise_pd.hpp"
-#include "softmax_pd.hpp"
-#include "inner_product_pd.hpp"
-#include "sum_pd.hpp"
-#include "lrn_pd.hpp"
-
-/* MKL-DNN CPU ISA info */
-#define ISA_ANY "No instruction set specific optimizations"
-#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)"
-#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)"
-#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)"
-#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
- "AVX-512)"
-#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
- "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions"
-#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \
- "AVX512-DL Boost)"
-#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
- "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions"
-#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
- "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions"
-
-namespace mkldnn {
-namespace impl {
-
-static verbose_t verbose;
-static bool initialized;
-static bool version_printed = false;
-
-const verbose_t *mkldnn_verbose() {
-#if !defined(DISABLE_VERBOSE)
- if (!initialized) {
- const int len = 2;
- char val[len] = {0};
- if (getenv("MKLDNN_VERBOSE", val, len) == 1)
- verbose.level = atoi(val);
- initialized = true;
- }
- if (!version_printed && verbose.level > 0) {
- printf("mkldnn_verbose,info,"
- "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n",
- mkldnn_version()->major, mkldnn_version()->minor,
- mkldnn_version()->patch, mkldnn_version()->hash,
- get_isa_info());
- version_printed = true;
- }
-#else
- verbose.level = 0;
-#endif
- return &verbose;
-}
-
-double get_msec() {
-#ifdef _WIN32
- static LARGE_INTEGER frequency;
- if (frequency.QuadPart == 0)
- QueryPerformanceFrequency(&frequency);
- LARGE_INTEGER now;
- QueryPerformanceCounter(&now);
- return 1e+3 * now.QuadPart / frequency.QuadPart;
-#else
- struct timeval time;
- gettimeofday(&time, NULL);
- return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
-#endif
-}
-
-const char *get_isa_info() {
- using namespace mkldnn::impl::cpu;
- if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS;
- if (mayiuse(avx512_mic)) return AVX512_MIC;
- if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI;
- if (mayiuse(avx512_core)) return AVX512_CORE;
- if (mayiuse(avx512_common)) return AVX512_COMMON;
- if (mayiuse(avx2)) return AVX2;
- if (mayiuse(avx)) return AVX;
- if (mayiuse(sse42)) return SSE42;
- return ISA_ANY;
-}
-
-/* init_info section */
-namespace {
-#if !defined(DISABLE_VERBOSE)
-#define MKLDNN_VERBOSE_DAT_LEN 256
-#define MKLDNN_VERBOSE_AUX_LEN 384
-#define MKLDNN_VERBOSE_PRB_LEN 384
-
-#define DECL_DAT_AUX_PRB_STRS() \
- int dat_written = 0, aux_written = 0, prb_written = 0; \
- MAYBE_UNUSED((dat_written * aux_written * prb_written)); \
- char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \
- char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \
- char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str)
-
-#define DFMT "%" PRId64
-
-void clear_buf(char *buf, int &written) {
- /* TODO: do it better */
- buf[0] = '#';
- buf[1] = '\0';
- written = 1;
-}
-
-#define DPRINT(buf, buf_len, written, ...) do { \
- int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \
- if (l < 0 || written + l > buf_len) { \
- clear_buf(buf, written); \
- } else { \
- written += l; \
- } \
-} while(0)
-
-// XXX: Outputs strings corresponding to memory formats used for data tensors.
-void format_prb_desc_str(char *str, int len, const memory_desc_t *md) {
- const auto dims = md->dims;
- int written = 0;
- if (md->ndims == 1)
- DPRINT(str, len, written,
- "x" DFMT, dims[0]);
- else if (md->ndims == 2)
- DPRINT(str, len, written,
- "mb" DFMT "ic" DFMT, dims[0], dims[1]);
- else if (md->ndims == 3)
- DPRINT(str, len, written,
- "mb" DFMT "ic" DFMT "iw" DFMT,
- dims[0], dims[1], dims[2]);
- else if (md->ndims == 4)
- DPRINT(str, len, written,
- "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT,
- dims[0], dims[1], dims[2], dims[3]);
- else if (md->ndims == 5)
- DPRINT(str, len, written,
- "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT,
- dims[0], dims[1], dims[2], dims[3], dims[4]);
- else
- mkldnn_md2dim_str(str, len, md);
-}
-
-void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind,
- const char *impl_str, mkldnn_prop_kind_t prop_kind,
- const char *data_str, const char *aux_str, const char *prb_str) {
- MAYBE_UNUSED(verbose_templ);
- int written = 0;
- DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s",
- mkldnn_prim_kind2str(prim_kind), impl_str,
- mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_bnorm(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // data
- auto md = s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // diff data
- auto md = s->diff_src_md();
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "flags:%u", s->desc()->flags);
-
- format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_conv(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // src
- auto md = s->desc()->prop_kind == prop_kind::backward_data
- ? s->diff_src_md() : s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // wei
- auto md = s->desc()->prop_kind == prop_kind::backward_weights
- ? s->diff_weights_md() : s->weights_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // bia
- auto md = s->desc()->prop_kind == prop_kind::backward_weights
- ? s->diff_weights_md(1) : s->weights_md(1);
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
- if (1) { // dst
- auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
-
- if (s->ndims() == 5) {
- if (s->with_groups())
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
- "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
- "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
- "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
- s->MB(), s->G(), s->IC(), s->OC(),
- s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
- else
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "_ic" DFMT "oc" DFMT
- "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
- "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
- "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
- s->MB(), s->IC(), s->OC(),
- s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
- } else {
- if (s->with_groups())
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
- "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
- "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
- s->MB(), s->G(), s->IC(), s->OC(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
- else
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "_ic" DFMT "oc" DFMT
- "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
- "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
- s->MB(), s->IC(), s->OC(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
- }
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_shuffle(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md();
-
- if (1) { // data
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "axis:%d group_size:" DFMT, s->axis(), s->group_size());
-
- mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md);
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_eltwise(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // data
- auto md = s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // diff data
- auto md = s->diff_src_md();
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
-
- mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_iprod(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // src
- auto md = s->desc()->prop_kind == prop_kind::backward_data
- ? s->diff_src_md() : s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // wei
- auto md = s->desc()->prop_kind == prop_kind::backward_weights
- ? s->diff_weights_md() : s->weights_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // bia
- auto md = s->desc()->prop_kind == prop_kind::backward_weights
- ? s->diff_weights_md(1) : s->weights_md(1);
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
- if (1) { // dst
- auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
-
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_lrn(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // data
- auto md = s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // diff data
- auto md = s->diff_src_md();
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
-
- format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_mem(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // src
- auto md = s->src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // dst
- auto md = s->dst_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "num:%d", s->n_inputs());
-
- mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
-
- verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_pool(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // src
- auto md = s->is_fwd() ? s->src_md() : s->diff_src_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // dst
- auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // ws
- auto md = s->workspace_md();
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
-
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
-
- if (s->is_3d()) {
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "ic" DFMT "_"
- "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_"
- "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
- "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "",
- s->MB(), s->C(),
- s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
- } else {
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "mb" DFMT "ic" DFMT "_"
- "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
- "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT,
- s->MB(), s->C(),
- s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
- s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
- }
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // data
- auto md = s->dst_md();
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // diff data
- auto md = s->diff_src_md();
- if (md) {
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- }
-
- mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) {
- DECL_DAT_AUX_PRB_STRS();
-
- if (1) { // src layer
- auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // src iter
- auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_iter_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // wei_layer
- auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // wei_iter
- auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1);
-        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_iter_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // bias
- auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // dst layer
- auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_layer_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
- if (1) { // dst iter
- auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1);
- DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "dst_iter_");
- int l = mkldnn_md2fmt_str(dat_str + dat_written,
- MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
- if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
- }
-
- alg_kind_t alg_kind = s->cell_kind();
- rnn_direction_t rnn_dir = s->direction();
- DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
- "alg:%s_%s", mkldnn_alg_kind2str(alg_kind),
- mkldnn_rnn_direction2str(rnn_dir));
-
- DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
- "l" DFMT "t" DFMT "mb" DFMT
- "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT,
- s->L(), s->T(), s->MB(),
- s->SIC(), s->SLC(), s->DIC(), s->DLC());
-
- verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
- aux_str, prb_str);
-}
-
-#undef DPRINT
-
-#else // !defined(DISABLE_VERBOSE)
-
-#define DEFINE_STUB(name) \
- template <typename pd_t> \
- static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \
- { UNUSED(s); UNUSED(buffer); }
-
-DEFINE_STUB(bnorm);
-DEFINE_STUB(conv);
-DEFINE_STUB(eltwise);
-DEFINE_STUB(iprod);
-DEFINE_STUB(lrn);
-DEFINE_STUB(mem);
-DEFINE_STUB(pool);
-DEFINE_STUB(softmax);
-DEFINE_STUB(rnn);
-DEFINE_STUB(shuffle);
-#undef DEFINE_STUB
-
-#endif // !defined(DISABLE_VERBOSE)
-}
-
-void init_info(batch_normalization_pd_t *s, char *b)
-{ init_info_bnorm(s, b); }
-void init_info(concat_pd_t *s, char *b)
-{ init_info_mem(s, b); }
-void init_info(convolution_pd_t *s, char *b)
-{ init_info_conv(s, b); }
-void init_info(deconvolution_pd_t *s, char *b)
-{ init_info_conv(s, b); }
-void init_info(eltwise_pd_t *s, char *b)
-{ init_info_eltwise(s, b); }
-void init_info(inner_product_pd_t *s, char *b)
-{ init_info_iprod(s, b); }
-void init_info(lrn_pd_t *s, char *b)
-{ init_info_lrn(s, b); }
-void init_info(pooling_pd_t *s, char *b)
-{ init_info_pool(s, b); }
-void init_info(reorder_pd_t *s, char *b)
-{ init_info_mem(s, b); }
-void init_info(rnn_pd_t *s, char *b)
-{ init_info_rnn(s, b); }
-void init_info(shuffle_pd_t *s, char *b)
-{ init_info_shuffle(s, b); }
-void init_info(softmax_pd_t *s, char *b)
-{ init_info_softmax(s, b); }
-void init_info(sum_pd_t *s, char *b)
-{ init_info_mem(s, b); }
-
-}
-}
-
-mkldnn_status_t mkldnn_set_verbose(int level) {
- using namespace mkldnn::impl::status;
- if (level < 0 || level > 2) return invalid_arguments;
- mkldnn::impl::verbose.level = level;
- mkldnn::impl::initialized = true;
- return success;
-}
-
-const mkldnn_version_t *mkldnn_version() {
- static mkldnn_version_t ver = {
- MKLDNN_VERSION_MAJOR,
- MKLDNN_VERSION_MINOR,
- MKLDNN_VERSION_PATCH,
- MKLDNN_VERSION_HASH};
- return &ver;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
deleted file mode 100644
index e3049750cb..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
+++ /dev/null
@@ -1,62 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef VERBOSE_HPP
-#define VERBOSE_HPP
-
-#include <stdio.h>
-#include <cinttypes>
-
-#include "mkldnn_debug.h"
-#include "c_types_map.hpp"
-#include "utils.hpp"
-#include "z_magic.hpp"
-
-namespace mkldnn {
-namespace impl {
-
-struct verbose_t {
- int level;
-};
-
-const verbose_t *mkldnn_verbose();
-double get_msec();
-const char *get_isa_info();
-
-#if !defined(DISABLE_VERBOSE)
-#define MKLDNN_VERBOSE_BUF_LEN 1024
-#else
-#define MKLDNN_VERBOSE_BUF_LEN 1
-#endif
-
-void init_info(batch_normalization_pd_t *s, char *buffer);
-void init_info(concat_pd_t *s, char *buffer);
-void init_info(convolution_pd_t *s, char *buffer);
-void init_info(deconvolution_pd_t *s, char *buffer);
-void init_info(eltwise_pd_t *s, char *buffer);
-void init_info(inner_product_pd_t *s, char *buffer);
-void init_info(lrn_pd_t *s, char *buffer);
-void init_info(pooling_pd_t *s, char *buffer);
-void init_info(reorder_pd_t *s, char *buffer);
-void init_info(rnn_pd_t *s, char *buffer);
-void init_info(shuffle_pd_t *s, char *buffer);
-void init_info(softmax_pd_t *s, char *buffer);
-void init_info(sum_pd_t *s, char *buffer);
-
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
deleted file mode 100644
index 520bd4710b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
+++ /dev/null
@@ -1,46 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef Z_MAGIC_HPP
-#define Z_MAGIC_HPP
-
-#define CHAIn2(a,b) a b
-#define CHAIN2(a,b) CHAIn2(a,b)
-
-#define CONCAt2(a,b) a ## b
-#define CONCAT2(a,b) CONCAt2(a,b)
-
-#define STRINGIFy(s) #s
-#define STRINGIFY(s) STRINGIFy(s)
-
-#ifdef _MSC_VER
-# define PRAGMA_MACRo(x) __pragma(x)
-# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
-#else
-# define PRAGMA_MACRo(x) _Pragma(#x)
-# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
-#endif
-
-#define UNUSED(x) ((void)x)
-#define MAYBE_UNUSED(x) UNUSED(x)
-
-#if defined(_WIN32) && !defined(__GNUC__)
-#define __PRETTY_FUNCTION__ __FUNCSIG__
-#endif
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp
deleted file mode 100644
index 7cf7822d90..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp
+++ /dev/null
@@ -1,112 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "cpu_barrier.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace simple_barrier {
-
-void generate(jit_generator &code, Xbyak::Reg64 reg_ctx,
- Xbyak::Reg64 reg_nthr) {
-# define BAR_CTR_OFF offsetof(ctx_t, ctr)
-# define BAR_SENSE_OFF offsetof(ctx_t, sense)
- using namespace Xbyak;
-
- Xbyak::Reg64 reg_tmp = [&]() {
- /* returns register which is neither reg_ctx nor reg_nthr */
- Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx };
- for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i)
- if (!utils::one_of(regs[i], reg_ctx, reg_nthr))
- return regs[i];
- return regs[0]; /* should not happen */
- }();
-
- Label barrier_exit_label, barrier_exit_restore_label, spin_label;
-
- code.cmp(reg_nthr, 1);
- code.jbe(barrier_exit_label);
-
- code.push(reg_tmp);
-
- /* take and save current sense */
- code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]);
- code.push(reg_tmp);
- code.mov(reg_tmp, 1);
-
- if (mayiuse(avx512_mic)) {
- code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]);
- code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]);
- }
-
- code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp);
- code.add(reg_tmp, 1);
- code.cmp(reg_tmp, reg_nthr);
- code.pop(reg_tmp); /* restore previous sense */
- code.jne(spin_label);
-
- /* the last thread {{{ */
- code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset ctx
-
- // notify waiting threads
- code.not_(reg_tmp);
- code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp);
- code.jmp(barrier_exit_restore_label);
- /* }}} the last thread */
-
- code.CodeGenerator::L(spin_label);
- code.pause();
- code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]);
- code.je(spin_label);
-
- code.CodeGenerator::L(barrier_exit_restore_label);
- code.pop(reg_tmp);
-
- code.CodeGenerator::L(barrier_exit_label);
-# undef BAR_CTR_OFF
-# undef BAR_SENSE_OFF
-}
-
-/** jit barrier generator */
-struct jit_t: public jit_generator {
- void (*barrier)(ctx_t *ctx, size_t nthr);
-
- jit_t() {
- generate(*this, abi_param1, abi_param2);
- ret();
- barrier = reinterpret_cast<decltype(barrier)>(const_cast<uint8_t*>(
- this->getCode()));
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t)
-};
-
-void barrier(ctx_t *ctx, int nthr) {
- static jit_t j; /* XXX: constructed on load ... */
- j.barrier(ctx, nthr);
-}
-
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp
deleted file mode 100644
index 0f55e33aa8..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp
+++ /dev/null
@@ -1,60 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_BARRIER_HPP
-#define CPU_BARRIER_HPP
-
-#include <assert.h>
-
-#include "jit_generator.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace simple_barrier {
-
-STRUCT_ALIGN(64,
-struct ctx_t {
- enum { CACHE_LINE_SIZE = 64 };
- volatile size_t ctr;
- char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)];
- volatile size_t sense;
- char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)];
-});
-
-inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero<ctx_t>(); }
-void barrier(ctx_t *ctx, int nthr);
-
-/** injects the actual barrier implementation into other jitted code
- * @params:
- *   code     -- jit_generator object where the barrier is to be injected
- *   reg_ctx  -- read-only register with a pointer to the barrier context
- *   reg_nthr -- read-only register with the # of synchronizing threads
- */
-void generate(jit_generator &code, Xbyak::Reg64 reg_ctx,
- Xbyak::Reg64 reg_nthr);
-
-}
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp
deleted file mode 100644
index 1ed5ad57b9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_BATCH_NORMALIZATION_PD_HPP
-#define CPU_BATCH_NORMALIZATION_PD_HPP
-
-#include "batch_normalization_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t {
- using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t;
-};
-
-struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t {
- using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp
deleted file mode 100644
index b8d5c4fcaf..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp
+++ /dev/null
@@ -1,140 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-#include "cpu_batch_normalization_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-namespace bnorm_utils {
-
-void cache_balance(size_t working_set_size, dim_t C_blks,
- dim_t &C_blks_per_iter, int64_t &iters) {
- int nthrs = mkldnn_get_max_threads();
- int l3_size = get_cache_size(3, true) * nthrs / 2;
-
- C_blks_per_iter = l3_size / working_set_size;
-
- if (C_blks_per_iter == 0)
- C_blks_per_iter = 1;
- if (C_blks_per_iter > C_blks)
- C_blks_per_iter = C_blks;
-
- iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter;
-}
-
-bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr,
- int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr,
- dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s,
- dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) {
- if (nthr <= C_blks || !mkldnn_thr_syncable()) {
- C_ithr = ithr; C_nthr = nthr;
- N_ithr = 0; N_nthr = 1;
- S_ithr = 0; S_nthr = 1;
- N_s = 0; N_e = N; S_s = 0; S_e = SP;
- balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
- } else {
- if (do_blocking) {
- N_nthr = (int)nstl::min<dim_t>(N, nthr);
- C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr);
- S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
- } else {
- C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
- N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
- S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
- }
-
- if (!spatial_thr_allowed)
- S_nthr = 1;
-
- if (S_nthr < 1) S_nthr = 1;
- if (ithr < C_nthr * N_nthr * S_nthr) {
- N_ithr = (ithr / S_nthr) % N_nthr;
- C_ithr = ithr / (N_nthr * S_nthr);
- S_ithr = ithr % S_nthr;
- balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
- balance211(N, N_nthr, N_ithr, N_s, N_e);
- balance211(SP, S_nthr, S_ithr, S_s, S_e);
- } else {
- S_ithr = N_ithr = C_ithr = -ithr;
- S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1;
- }
- }
-
- // spatial_thr_allowed is meant to help maintain
- // consistent decisions about spatial threading
- // between multiple invocations of this routine.
- // It is the caller's responsibility to check the
- // return value and pass it as a flag to the
- // next call if needed.
- if (S_nthr == 1)
- spatial_thr_allowed = false;
-
- return spatial_thr_allowed;
-}
-
-bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w,
- int data_size) {
- if (!mkldnn_thr_syncable()) return false;
-
- dim_t nthr = mkldnn_get_max_threads();
- dim_t SP = bdesc->W() * bdesc->D() * bdesc->H();
- dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md())
- .padded_dims()[1];
- assert(C_PADDED % simd_w == 0);
-
- size_t data = bdesc->MB() * C_PADDED * SP * data_size;
- size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
- bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0);
- dim_t C_blks_per_iter{ 1 }, iters{ 1 };
- dim_t C_blks = C_PADDED / simd_w;
-
- if (do_blocking) {
- int num_tensors = bdesc->is_fwd() ? 1 : 2;
- size_t working_set_size
- = (bdesc->MB() * SP * simd_w * data_size) * num_tensors;
- cache_balance(working_set_size, C_blks, C_blks_per_iter, iters);
- }
-
- // Spatial threading decision made in this function shall be consistent
- // with thread_balance() behavior.
- C_blks = do_blocking ? C_blks_per_iter : C_blks;
-
- if (nthr <= C_blks) return false;
-
- dim_t S_nthr = 1;
- if (do_blocking) {
- dim_t N_nthr = nstl::min(bdesc->MB(), nthr);
- dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr);
- S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
- } else {
- dim_t C_nthr = math::gcd(nthr, C_blks);
- dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr);
- S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
- }
-
- return S_nthr > 1;
-}
-
-}
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp
deleted file mode 100644
index 0daef0716c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp
+++ /dev/null
@@ -1,43 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP
-#define CPU_BATCH_NORMALIZATION_UTILS_HPP
-
-#include "batch_normalization_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-namespace bnorm_utils {
-
-void cache_balance(size_t working_set_size, dim_t C_blks,
- dim_t &C_blks_per_iter, int64_t &iters);
-
-bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr,
- int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr,
- dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s,
- dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e);
-
-bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w,
- int data_size);
-
-}
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp
deleted file mode 100644
index b926491202..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "cpu_engine.hpp"
-
-/*
-#include "cpu/ref_concat.hpp"
-#include "cpu/simple_concat.hpp"
-*/
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f;
-
-namespace {
-#define INSTANCE(...) __VA_ARGS__::pd_t::create
-static const cpd_create_f cpu_concat_impl_list[] = {
- /*
- INSTANCE(simple_concat_t<data_type::f32>),
- INSTANCE(simple_concat_t<data_type::u8>),
- INSTANCE(simple_concat_t<data_type::s8>),
- INSTANCE(simple_concat_t<data_type::s32>),
- INSTANCE(ref_concat_t),
- */
- nullptr,
-};
-#undef INSTANCE
-}
-
-const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const {
- return cpu_concat_impl_list;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp
deleted file mode 100644
index 0b01bcf163..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_CONCAT_PD_HPP
-#define CPU_CONCAT_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "concat_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_concat_pd_t: public concat_pd_t {
- using concat_pd_t::concat_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp
deleted file mode 100644
index 52a38a2294..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp
+++ /dev/null
@@ -1,74 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_CONVOLUTION_PD_HPP
-#define CPU_CONVOLUTION_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "convolution_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
- using convolution_fwd_pd_t::convolution_fwd_pd_t;
-
- bool has_padded_dst() const {
- memory_desc_wrapper dst_d(&dst_md_);
- return OC() != dst_d.padded_dims()[1];
- }
-
- bool wants_padded_bias() const {
- if (!with_bias()) return false;
- return has_padded_dst();
- }
-
- bool wants_zero_pad_dst(bool jit_impl = true) const {
- if (!has_padded_dst()) return false;
- const auto &po = attr()->post_ops_;
- int idx;
- if ((idx = po.find(primitive_kind::eltwise)) == -1) return false;
- return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg,
- jit_impl);
- }
-};
-
-struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
- using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t;
-};
-
-struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
- using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t;
-
- bool wants_padded_bias() const {
- if (!with_bias()) return false;
- memory_desc_wrapper diff_dst_d(&diff_dst_md_);
- return OC() != diff_dst_d.padded_dims()[1];
- }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp
deleted file mode 100644
index 164c8601d7..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp
+++ /dev/null
@@ -1,46 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_DECONVOLUTION_PD_HPP
-#define CPU_DECONVOLUTION_PD_HPP
-
-#include <assert.h>
-
-#include "deconvolution_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t {
- using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t;
-};
-
-struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t {
- using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t;
-};
-
-struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t {
- using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp
deleted file mode 100644
index c52f00026e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp
+++ /dev/null
@@ -1,45 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_ELTWISE_PD_HPP
-#define CPU_ELTWISE_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "eltwise_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t {
- using eltwise_fwd_pd_t::eltwise_fwd_pd_t;
-};
-
-struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t {
- using eltwise_bwd_pd_t::eltwise_bwd_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp
deleted file mode 100644
index ce0a3667ad..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp
+++ /dev/null
@@ -1,324 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "type_helpers.hpp"
-#include "verbose.hpp"
-
-#include "cpu_engine.hpp"
-#include "cpu_memory.hpp"
-
-//#include "cpu/rnn/ref_rnn.hpp"
-
-//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
-//#include "cpu/jit_avx512_common_1x1_convolution.hpp"
-#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp"
-#include "cpu/jit_avx512_common_convolution_winograd.hpp"
-//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp"
-#include "cpu/jit_avx512_common_convolution.hpp"
-//#include "cpu/jit_avx2_1x1_convolution.hpp"
-//#include "cpu/jit_sse42_1x1_convolution.hpp"
-#include "cpu/jit_avx2_convolution.hpp"
-#include "cpu/jit_sse42_convolution.hpp"
-//#include "cpu/gemm_convolution.hpp"
-//#include "cpu/gemm_x8s8s32x_convolution.hpp"
-//#include "cpu/ref_convolution.hpp"
-//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
-//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
-//#include "cpu/ref_deconvolution.hpp"
-//#include "cpu/ref_shuffle.hpp"
-//#include "cpu/jit_uni_eltwise.hpp"
-//#include "cpu/ref_eltwise.hpp"
-//#include "cpu/ref_softmax.hpp"
-#include "cpu/jit_uni_pooling.hpp"
-//#include "cpu/jit_uni_i8i8_pooling.hpp"
-//#include "cpu/ref_pooling.hpp"
-//#include "cpu/nchw_pooling.hpp"
-//#include "cpu/nhwc_pooling.hpp"
-//#include "cpu/jit_avx512_common_lrn.hpp"
-//#include "cpu/jit_uni_lrn.hpp"
-//#include "cpu/ref_lrn.hpp"
-//#include "cpu/jit_uni_batch_normalization.hpp"
-//#include "cpu/ref_batch_normalization.hpp"
-//#include "cpu/ncsp_batch_normalization.hpp"
-//#include "cpu/nspc_batch_normalization.hpp"
-//#include "cpu/ref_inner_product.hpp"
-//#include "cpu/gemm_inner_product.hpp"
-//#include "cpu/gemm_x8s8s32x_inner_product.hpp"
-//#include "cpu/jit_uni_dw_convolution.hpp"
-//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
-#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-status_t cpu_engine_t::memory_create(memory_t **memory,
- const memory_desc_t *md, void *handle) {
- auto _memory = new cpu_memory_t(this, md, handle);
- if (_memory == nullptr)
- return status::out_of_memory;
-
- status_t status = _memory->init();
- if (status != status::success) {
- delete _memory;
- return status;
- }
-
- return safe_ptr_assign<memory_t>(*memory, _memory);
-}
-
-using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
-
-namespace {
-using namespace mkldnn::impl::data_type;
-
-#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>
-static const pd_create_f cpu_impl_list[] = {
- /* RNN */
- /*
- INSTANCE(ref_rnn_fwd_f32_t),
- INSTANCE(ref_rnn_fwd_u8s8_t),
- INSTANCE(ref_rnn_bwd_f32_t),
- */
- /* conv */
- /*
- INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
- INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
- INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
- INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t),
- INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t),
- INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t),
- */
- INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t),
- INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t),
- //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t),
- //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t),
- INSTANCE(jit_avx512_common_convolution_winograd_fwd_t),
- //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t),
- //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t),
- INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
- //INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
- //INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
- /*
- INSTANCE(jit_avx2_dw_convolution_fwd_t),
- INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
- INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
- INSTANCE(jit_avx2_1x1_convolution_fwd_t),
- INSTANCE(jit_avx2_1x1_convolution_bwd_data_t),
- INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t),
- INSTANCE(jit_sse42_dw_convolution_fwd_t),
- INSTANCE(jit_sse42_dw_convolution_bwd_data_t),
- INSTANCE(jit_sse42_dw_convolution_bwd_weights_t),
- INSTANCE(jit_sse42_1x1_convolution_fwd_t),
- */
- INSTANCE(jit_avx2_convolution_fwd_t),
- //INSTANCE(jit_avx2_convolution_bwd_data_t),
- //INSTANCE(jit_avx2_convolution_bwd_weights_t),
- INSTANCE(jit_sse42_convolution_fwd_t),
- /*
- INSTANCE(gemm_convolution_fwd_t),
- INSTANCE(gemm_convolution_bwd_data_t),
- INSTANCE(gemm_convolution_bwd_weights_t),
- INSTANCE(ref_convolution_fwd_t<f32>),
- INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>),
- INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>),
- */
- /* conv (int) */
- /*
- INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>),
- INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>),
- INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>),
- INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
- INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
- INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
- INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
- INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
- INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>),
- INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>),
- INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>),
- INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>),
- INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>),
- INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>),
- INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>),
- INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>),
- INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
- */
- /* deconv */
- /*
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
- INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
- INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
- INSTANCE(ref_deconvolution_bwd_weights_t),
- INSTANCE(ref_deconvolution_bwd_data_t),
- INSTANCE(ref_deconvolution_fwd_t),
- */
- /* shuffle */
- /*
- INSTANCE(ref_shuffle_t<4>), // f32 or s32
- INSTANCE(ref_shuffle_t<1>), // s8 or u8
- */
- /* eltwise */
- /*
- INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>),
- INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>),
- INSTANCE(jit_uni_eltwise_fwd_t<avx2>),
- INSTANCE(jit_uni_eltwise_bwd_t<avx2>),
- INSTANCE(jit_uni_eltwise_fwd_t<sse42>),
- INSTANCE(jit_uni_eltwise_bwd_t<sse42>),
- INSTANCE(ref_eltwise_fwd_t<f32>),
- INSTANCE(ref_eltwise_bwd_t<f32>),
- */
- /* eltwise (int) */
- /*
- INSTANCE(ref_eltwise_fwd_t<s32>),
- INSTANCE(ref_eltwise_fwd_t<s8>),
- INSTANCE(ref_eltwise_fwd_t<u8>),
- INSTANCE(ref_eltwise_bwd_t<s32>),
- */
- /* softmax */
- /*
- INSTANCE(ref_softmax_fwd_t<f32>),
- INSTANCE(ref_softmax_bwd_t<f32>),
- */
- /* pool */
- INSTANCE(jit_uni_pooling_fwd_t<avx512_common>),
- //INSTANCE(jit_uni_pooling_bwd_t<avx512_common>),
- INSTANCE(jit_uni_pooling_fwd_t<avx>),
- //INSTANCE(jit_uni_pooling_bwd_t<avx>),
- INSTANCE(jit_uni_pooling_fwd_t<sse42>),
- //INSTANCE(jit_uni_pooling_bwd_t<sse42>),
- /*
- INSTANCE(nchw_pooling_fwd_t<f32>),
- INSTANCE(nchw_pooling_bwd_t<f32>),
- INSTANCE(nhwc_pooling_fwd_t<f32>),
- INSTANCE(nhwc_pooling_bwd_t<f32>),
- INSTANCE(ref_pooling_fwd_t<f32>),
- INSTANCE(ref_pooling_bwd_t<f32>),
- */
- /* pool (int) */
- /*
- INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
- INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
- INSTANCE(ref_pooling_fwd_t<s32>),
- INSTANCE(ref_pooling_fwd_t<s8, s32>),
- INSTANCE(ref_pooling_fwd_t<u8, s32>),
- INSTANCE(ref_pooling_bwd_t<s32>),
- */
- /* lrn */
- /*
- INSTANCE(jit_avx512_common_lrn_fwd_t),
- INSTANCE(jit_avx512_common_lrn_bwd_t),
- INSTANCE(jit_uni_lrn_fwd_t<avx2>),
- INSTANCE(jit_uni_lrn_bwd_t<avx2>),
- INSTANCE(jit_uni_lrn_fwd_t<sse42>),
- INSTANCE(ref_lrn_fwd_t<f32>),
- INSTANCE(ref_lrn_bwd_t<f32>),
- */
- /* batch normalization */
- /*
- INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>),
- INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>),
- INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>),
- INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>),
- INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>),
- INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>),
- INSTANCE(ncsp_batch_normalization_fwd_t),
- INSTANCE(ncsp_batch_normalization_bwd_t),
- INSTANCE(nspc_batch_normalization_fwd_t),
- INSTANCE(nspc_batch_normalization_bwd_t),
- INSTANCE(ref_batch_normalization_fwd_t<f32>),
- INSTANCE(ref_batch_normalization_bwd_t<f32>),
- INSTANCE(ref_batch_normalization_fwd_t<s8>),
- */
- /* inner product */
- /*
- INSTANCE(gemm_inner_product_fwd_t<f32>),
- INSTANCE(gemm_inner_product_bwd_data_t<f32>),
- INSTANCE(gemm_inner_product_bwd_weights_t<f32>),
- INSTANCE(ref_inner_product_fwd_t<f32>),
- INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
- INSTANCE(ref_inner_product_bwd_weights_t<f32>),
- */
- /* inner product (int) */
- /*
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
- INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
- INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
- INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
- INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
- INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
- */
- /* eol */
- nullptr,
-};
-#undef INSTANCE
-}
-
-const pd_create_f* cpu_engine_t::get_implementation_list() const {
- return cpu_impl_list;
-}
-
-cpu_engine_factory_t engine_factory;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp
deleted file mode 100644
index e4c877ee05..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp
+++ /dev/null
@@ -1,70 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_ENGINE_HPP
-#define CPU_ENGINE_HPP
-
-#include <assert.h>
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "../common/engine.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-class cpu_engine_t: public engine_t {
-public:
- cpu_engine_t(): engine_t(engine_kind::cpu) {}
-
- /* implementation part */
-
- virtual status_t memory_create(memory_t **memory,
- const memory_desc_t *md, void *handle) override;
-
- virtual const concat_primitive_desc_create_f*
- get_concat_implementation_list() const override;
- virtual const reorder_primitive_desc_create_f*
- get_reorder_implementation_list() const override;
- virtual const sum_primitive_desc_create_f*
- get_sum_implementation_list() const override;
- virtual const primitive_desc_create_f*
- get_implementation_list() const override;
-};
-
-class cpu_engine_factory_t: public engine_factory_t {
-public:
- virtual size_t count() const override { return 1; }
- virtual engine_kind_t kind() const override { return engine_kind::cpu; }
- virtual status_t engine_create(engine_t **engine,
- size_t index) const override {
- assert(index == 0);
- *engine = new cpu_engine_t();
- return status::success;
- };
-};
-
-extern cpu_engine_factory_t engine_factory;
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp
deleted file mode 100644
index 5880d3450c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp
+++ /dev/null
@@ -1,84 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_INNER_PRODUCT_PD_HPP
-#define CPU_INNER_PRODUCT_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "inner_product_pd.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace {
-inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) {
- using namespace utils;
-
- auto strides_compatible = [&]() {
- bool ok = true;
- auto w_str = wei_d.blocking_desc().strides;
- auto d_str = src_d.blocking_desc().strides;
- for (int i = 1; i < src_d.ndims() - 1; i++) {
- ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1];
- }
- return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]);
- };
- return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc()
- && src_d.ndims() == wei_d.ndims()
- && src_d.blocking_desc().inner_nblks
- == wei_d.blocking_desc().inner_nblks
- && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1)
- && array_cmp(src_d.blocking_desc().inner_blks,
- wei_d.blocking_desc().inner_blks,
- wei_d.blocking_desc().inner_nblks)
- && array_cmp(src_d.blocking_desc().inner_idxs,
- wei_d.blocking_desc().inner_idxs,
- wei_d.blocking_desc().inner_nblks)
- && strides_compatible()
- && dst_d.matches_tag(format_tag::nc)
- && src_d.only_padded_dim(1)
- && wei_d.only_padded_dim(1)
- && src_d.padded_dims()[1] == wei_d.padded_dims()[1]
- && src_d.is_dense(true)
- && dst_d.is_dense()
- && wei_d.is_dense(true);
-}
-}
-
-struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t {
- using inner_product_fwd_pd_t::inner_product_fwd_pd_t;
-};
-
-struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t {
- using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t;
-};
-
-struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t {
- using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp
deleted file mode 100644
index da6e9dac8e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp
+++ /dev/null
@@ -1,151 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_ISA_TRAITS_HPP
-#define CPU_ISA_TRAITS_HPP
-
-#include <type_traits>
-
-#define XBYAK64
-#define XBYAK_NO_OP_NAMES
-/* In order to make SELinux happy, memory that would be marked with the X-bit
- * should be obtained with mmap. */
-#define XBYAK_USE_MMAP_ALLOCATOR
-#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
-/* turn off `size_t to other-type implicit casting` warning
- * currently we have a lot of jit-generated instructions that
- * take uint32_t, but we pass size_t (e.g. due to using sizeof).
- * FIXME: replace size_t parameters with the appropriate ones */
-#pragma warning (disable: 4267)
-#endif
-#include "xbyak/xbyak.h"
-#include "xbyak/xbyak_util.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-typedef enum {
- isa_any,
- sse41,
- sse42,
- avx,
- avx2,
- avx512_common,
- avx512_core,
- avx512_core_vnni,
- avx512_mic,
- avx512_mic_4ops,
-} cpu_isa_t;
-
-template <cpu_isa_t> struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */
-
-template <> struct cpu_isa_traits<sse42> {
- typedef Xbyak::Xmm Vmm;
- static constexpr int vlen_shift = 4;
- static constexpr int vlen = 16;
- static constexpr int n_vregs = 16;
-};
-template <> struct cpu_isa_traits<avx> {
- typedef Xbyak::Ymm Vmm;
- static constexpr int vlen_shift = 5;
- static constexpr int vlen = 32;
- static constexpr int n_vregs = 16;
-};
-template <> struct cpu_isa_traits<avx2>:
- public cpu_isa_traits<avx> {};
-
-template <> struct cpu_isa_traits<avx512_common> {
- typedef Xbyak::Zmm Vmm;
- static constexpr int vlen_shift = 6;
- static constexpr int vlen = 64;
- static constexpr int n_vregs = 32;
-};
-template <> struct cpu_isa_traits<avx512_core>:
- public cpu_isa_traits<avx512_common> {};
-
-template <> struct cpu_isa_traits<avx512_mic>:
- public cpu_isa_traits<avx512_common> {};
-
-template <> struct cpu_isa_traits<avx512_mic_4ops>:
- public cpu_isa_traits<avx512_common> {};
-
-namespace {
-
-static Xbyak::util::Cpu cpu;
-static inline bool mayiuse(const cpu_isa_t cpu_isa) {
- using namespace Xbyak::util;
-
- switch (cpu_isa) {
- case sse41:
- case sse42:
- // FIXME: SSE4.2 is actually NOT required
- //return cpu.has(Cpu::tSSE42);
- return cpu.has(Cpu::tSSE41);
- case avx:
- return cpu.has(Cpu::tAVX);
- case avx2:
- return cpu.has(Cpu::tAVX2);
- case avx512_common:
- return cpu.has(Cpu::tAVX512F);
- case avx512_core:
- return true
- && cpu.has(Cpu::tAVX512F)
- && cpu.has(Cpu::tAVX512BW)
- && cpu.has(Cpu::tAVX512VL)
- && cpu.has(Cpu::tAVX512DQ);
- case avx512_core_vnni:
- return true
- && cpu.has(Cpu::tAVX512F)
- && cpu.has(Cpu::tAVX512BW)
- && cpu.has(Cpu::tAVX512VL)
- && cpu.has(Cpu::tAVX512DQ)
- && cpu.has(Cpu::tAVX512_VNNI);
- case avx512_mic:
- return true
- && cpu.has(Cpu::tAVX512F)
- && cpu.has(Cpu::tAVX512CD)
- && cpu.has(Cpu::tAVX512ER)
- && cpu.has(Cpu::tAVX512PF);
- case avx512_mic_4ops:
- return true
- && mayiuse(avx512_mic)
- && cpu.has(Cpu::tAVX512_4FMAPS)
- && cpu.has(Cpu::tAVX512_4VNNIW);
- case isa_any:
- return true;
- }
- return false;
-}
-}
-
-/* whatever is required to generate string literals... */
-#include "z_magic.hpp"
-#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \
- (isa == sse42 ? prefix STRINGIFY(sse42) : \
- (isa == avx ? prefix STRINGIFY(avx) : \
- (isa == avx2 ? prefix STRINGIFY(avx2) : \
- (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \
- (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \
- (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \
- (isa == avx512_mic_4ops ? prefix STRINGIFY(avx512_mic_4ops) : \
- prefix suffix_if_any)))))))
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp
deleted file mode 100644
index 49988f4c2d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp
+++ /dev/null
@@ -1,42 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_LRN_PD_HPP
-#define CPU_LRN_PD_HPP
-
-#include <assert.h>
-
-#include "lrn_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t {
- using lrn_fwd_pd_t::lrn_fwd_pd_t;
-};
-
-struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t {
- using lrn_bwd_pd_t::lrn_bwd_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp
deleted file mode 100644
index 3c0624cf46..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp
+++ /dev/null
@@ -1,277 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "mkldnn_traits.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_memory.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::data_type;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::format_tag;
-
-enum blk_kind_t { a, b, c, ab, ba, bc, cb };
-
-template <data_type_t dt, blk_kind_t blk_kind, int blksize>
-void typed_zero_pad_blk(
- const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
- using data_t = typename prec_traits<dt>::type;
- const auto &dims = m_d.dims();
- const auto &pdims = m_d.padded_dims();
- const auto &blk = m_d.blocking_desc();
- auto dim_is_blocked = [&](int dim) {
- for (int i = 0; i < blk.inner_nblks; i++)
- if (blk.inner_idxs[i] == dim)
- return true;
- return false;
- };
- bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1),
- C_blocked = dim_is_blocked(2);
-
- assert(blk.inner_nblks < 4);
- assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked)
- || (C_blocked && B_blocked));
-
- const int a_tail_s = A_blocked ? dims[0] % blksize : 0;
- const int b_tail_s = B_blocked ? dims[1] % blksize : 0;
- const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
- assert(a_tail_s || b_tail_s || c_tail_s);
-
- const int A = A_blocked ? pdims[0] / blksize : dims[0];
- const int B = B_blocked ? pdims[1] / blksize : dims[1];
- const int C = C_blocked ? pdims[2] / blksize : dims[2];
- const int D = m_d.ndims() > 3 ? dims[3] : 1;
- const int E = m_d.ndims() > 4 ? dims[4] : 1;
- const int F = m_d.ndims() > 5 ? dims[5] : 1;
- const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;
-
- auto zeroize_tail = [&](data_t *d, const int tail_s) {
- for (int b = tail_s; b < blksize; ++b)
- d[b] = 0;
- };
- auto zeroize_tail_inner = [&](data_t *d, const int tail_s) {
- for (int b1 = 0; b1 < blksize; ++b1)
- for (int b2 = tail_s; b2 < blksize; ++b2)
- d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
- + b1 % inner_blk]
- = 0;
- };
- auto zeroize_tail_outer = [&](data_t *d, const int tail_s) {
- for (int b1 = tail_s; b1 < blksize; ++b1)
- for (int b2 = 0; b2 < blksize; ++b2)
- d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
- + b1 % inner_blk]
- = 0;
- };
-
- if (c_tail_s) {
- parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) {
- auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)];
- if (blk_kind == c)
- zeroize_tail(x, c_tail_s);
- else if (blk_kind == bc)
- zeroize_tail_inner(x, c_tail_s);
- else if (blk_kind == cb)
- zeroize_tail_outer(x, c_tail_s);
- });
- }
-
- if (b_tail_s) {
- parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) {
- auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)];
- if (blk_kind == b)
- zeroize_tail(x, b_tail_s);
- else if (blk_kind == ab || blk_kind == cb)
- zeroize_tail_inner(x, b_tail_s);
- else if (blk_kind == ba || blk_kind == bc)
- zeroize_tail_outer(x, b_tail_s);
- });
- }
-
- if (a_tail_s) {
- parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) {
- auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)];
- if (blk_kind == a)
- zeroize_tail(x, a_tail_s);
- else if (blk_kind == ba)
- zeroize_tail_inner(x, a_tail_s);
- else if (blk_kind == ab)
- zeroize_tail_outer(x, a_tail_s);
- });
- }
-}
-
-/*
- * Generic zero-padding fallback that handles all blocked formats.
- */
-template <data_type_t dt>
-void typed_zero_pad_generic_blocked(
- const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
- const int ndims = m_d.ndims();
- const auto &dims = m_d.dims();
- const auto &pdims = m_d.padded_dims();
-
- const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);
-
- /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
- * | \ /
- * | ---------------------
- * has contiguous
- * padding
- *
- * step <-- D_k+1 * ... * D_ndims-1
- * step_dim <-- k
- */
-
- ptrdiff_t step = 1;
- int step_dim = ndims - 1;
- for (; step_dim >= 0; --step_dim) {
- if (dims[step_dim] != pdims[step_dim])
- break;
- step *= dims[step_dim];
- }
-
- assert(step_dim >= 0 && "no zero padding is required");
- if (step_dim < 0)
- return;
-
- parallel_nd(nelems / step, [&](ptrdiff_t e1) {
- bool need_zero = false;
-
- ptrdiff_t idx = e1;
- for (int d = step_dim; d >= 0; --d) {
- if (idx % pdims[d] >= dims[d]) {
- need_zero = true;
- break;
- }
- idx /= pdims[d];
- }
-
- if (need_zero) {
- for (ptrdiff_t e0 = 0; e0 < step; ++e0)
- data[m_d.off_l(e1 * step + e0, true)] = 0;
- }
- });
-}
-
-template <data_type_t dt>
-status_t cpu_memory_t::typed_zero_pad() const {
- const memory_desc_wrapper mdw(md());
-
- if (mdw.format_kind() != format_kind::blocked)
- return unimplemented;
-
- if (mdw.nelems(false) == mdw.nelems(true))
- return success;
-
- auto *data = (typename prec_traits<dt>::type *)data_;
- auto blk = mdw.blocking_desc();
-
- auto get_blksize = [&](int ind) {
- int blksize = 1;
- for (int i = 0; i < blk.inner_nblks; i++) {
- if (blk.inner_idxs[i] == ind)
- blksize *= blk.inner_blks[i];
- }
- return blksize;
- };
- const int blksize = get_blksize(blk.inner_idxs[0]);
-
-# define CASE(blksize_, blk_kind) \
- do { \
- if (blksize == blksize_) { \
- typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
- return success; \
- } \
- } while(0)
-
- switch (blk.inner_nblks) {
- case 1:
- if (blk.inner_idxs[0] == 0) {
- CASE(4, a);
- CASE(8, a);
- CASE(16, a);
- } else if (blk.inner_idxs[0] == 1) {
- CASE(4, b);
- CASE(8, b);
- CASE(16, b);
- }
- break;
- case 2:
- case 3:
- if (!IMPLICATION(blk.inner_nblks == 3,
- blk.inner_idxs[0] == blk.inner_idxs[2]))
- break;
-
- if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
- CASE(4, ab);
- CASE(8, ab);
- CASE(16, ab);
- } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) {
- CASE(4, ba);
- CASE(8, ba);
- CASE(16, ba);
- }
- if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
- CASE(4, bc);
- CASE(8, bc);
- CASE(16, bc);
- } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) {
- CASE(4, cb);
- CASE(8, cb);
- CASE(16, cb);
- }
- break;
- default: break;
- }
-
-# undef CASE
-
- // the last line of defence
- typed_zero_pad_generic_blocked<dt>(mdw, data);
- return success;
-}
-
-status_t cpu_memory_t::zero_pad() const {
- memory_desc_wrapper mdw(md());
- const bool skip_zeroing = false
- || data_ == nullptr
- || mdw.is_zero()
- || !mdw.is_blocking_desc();
- if (skip_zeroing) return success;
-
- switch (mdw.data_type()) {
- case f32: return typed_zero_pad<f32>();
- case s32: return typed_zero_pad<s32>();
- case s8: return typed_zero_pad<s8>();
- case u8: return typed_zero_pad<u8>();
- default: assert(!"memory is undefined"); return unimplemented;
- }
- return unimplemented;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp
deleted file mode 100644
index 2c01bcc6af..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp
+++ /dev/null
@@ -1,89 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_MEMORY_HPP
-#define CPU_MEMORY_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory.hpp"
-#include "memory_desc_wrapper.hpp"
-
-#include "cpu_engine.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_memory_t: public memory_t {
- cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle)
- : memory_t(engine, md)
- , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE)
- , data_((char *)handle) {}
-
- cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md)
- : cpu_memory_t(engine, md, nullptr) {}
-
- ~cpu_memory_t() { if (own_data_) free(data_); }
-
- virtual status_t init() override {
- if (own_data_) {
- data_ = nullptr;
- const size_t size = memory_desc_wrapper(this->md()).size();
- if (size) {
- data_ = (char *)malloc(size, 64);
- if (data_ == nullptr)
- return status::out_of_memory;
- }
- }
- return zero_pad();
- }
-
- cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); }
-
- virtual status_t get_data_handle(void **handle) const override {
- *handle = static_cast<void *>(data_);
- return status::success;
- }
-
- virtual mkldnn::impl::status_t set_data_handle(void *handle) override {
- if (own_data_) { free(data_); own_data_ = false; }
- data_ = static_cast<char *>(handle);
- return zero_pad();
- }
-
- virtual mkldnn::impl::status_t zero_pad() const override;
-
-private:
- bool own_data_;
- char *data_;
-
- template <mkldnn::impl::data_type_t>
- mkldnn::impl::status_t typed_zero_pad() const;
-
- cpu_memory_t(const cpu_memory_t &) = delete;
- cpu_memory_t &operator=(const cpu_memory_t &) = delete;
- cpu_memory_t &operator=(cpu_memory_t &&) = delete;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp
deleted file mode 100644
index ac2daa415e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_POOLING_PD_HPP
-#define CPU_POOLING_PD_HPP
-
-#include "pooling_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t {
- using pooling_fwd_pd_t::pooling_fwd_pd_t;
-};
-
-struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
- using pooling_bwd_pd_t::pooling_bwd_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp
deleted file mode 100644
index 56127f36c2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp
+++ /dev/null
@@ -1,83 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_PRIMITIVE_HPP
-#define CPU_PRIMITIVE_HPP
-
-#include "mkldnn.h"
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "primitive.hpp"
-#include "scratchpad.hpp"
-
-#define CTX_IN_MEM(type, arg) static_cast<type>(ctx.input(arg))
-#define CTX_OUT_MEM(type, arg) static_cast<type>(ctx.output(arg))
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_memory_t;
-
-struct cpu_primitive_t: public primitive_t {
- cpu_primitive_t(const primitive_desc_t *pd,
- bool use_global_scratchpad = false)
- : primitive_t(pd)
- , scratchpad_buffer_(nullptr)
- , global_scratchpad_(nullptr)
- {
- const size_t scratchpad_size =
- this->pd()->scratchpad_size(scratchpad_mode::library);
-
- if (scratchpad_size) {
- if (use_global_scratchpad)
- global_scratchpad_ = create_scratchpad(scratchpad_size);
- else
- scratchpad_buffer_ = malloc(scratchpad_size, 64);
- }
- }
-
- virtual ~cpu_primitive_t() {
- delete global_scratchpad_;
- free(scratchpad_buffer_);
- }
-
-protected:
- memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const {
- void *ptr = nullptr;
- if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) {
- ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD);
- } else {
- ptr = global_scratchpad_
- ? global_scratchpad_->get() : scratchpad_buffer_;
- }
-
- return pd()->scratchpad_registry().grantor(ptr);
- }
-
-private:
- void *scratchpad_buffer_;
- scratchpad_t *global_scratchpad_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp
deleted file mode 100644
index 1d41ac5cea..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp
+++ /dev/null
@@ -1,544 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "mkldnn_thread.hpp"
-#include "mkldnn_types.h"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "cpu_reducer.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace memory_tracking::names;
-
-void reduce_balancer_t::balance() {
- using namespace nstl;
- using namespace utils;
-
- assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0);
-
- const int job_complexity = 1;
-
- const int min_njobs_per_group = max(1, njobs_ / nthr_);
- const int max_njobs_per_group = max(1,
- static_cast<int>(max_buffer_size_ / (nthr_ * job_size_)));
-
- /* initial guess */
- int ngroups = min(njobs_ / min_njobs_per_group, nthr_);
- int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1;
- int njobs_per_group_ub = div_up(njobs_, ngroups);
-
- /* rough upper-bound estimation, will be fixed during brute force */
- size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_;
-
- /* brute force parameters for the best balance... */
- for (int c_njobs_per_group = min_njobs_per_group;
- c_njobs_per_group < njobs_; ++c_njobs_per_group) {
- /* current assumption */
- int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_);
- int c_nthr_per_group = syncable_
- ? min(nthr_ / c_ngroups, reduction_size_) : 1;
- int c_njobs_per_group_ub = div_up(njobs_, c_ngroups);
-
- if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group)
- continue;
-
- int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group);
- size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub;
- size_t c_thread_complexity_ub = c_group_size_ub * (
- job_complexity * c_thread_reduction_ub
- + (c_nthr_per_group != 1));
-
- if (c_thread_complexity_ub < thread_complexity_ub) {
- ngroups = c_ngroups;
- nthr_per_group = c_nthr_per_group;
- njobs_per_group_ub = c_njobs_per_group_ub;
- thread_complexity_ub = c_thread_complexity_ub;
- }
- }
-
- assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1);
- assert(ngroups * nthr_per_group <= nthr_);
- assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_
- || nthr_per_group == 1); /* no reduction buffer overflow */
- assert(IMPLICATION(!syncable_, nthr_per_group == 1));
-
- ngroups_ = ngroups;
- nthr_per_group_ = nthr_per_group;
- njobs_per_group_ub_ = njobs_per_group_ub;
-}
-
-/* reducer jit-ted driver */
-
-using namespace Xbyak;
-
-template <impl::data_type_t data_type>
-struct reducer_2d_driver_t: public c_compatible {
- typedef typename prec_traits<data_type>::type data_t;
-
- reducer_2d_driver_t(int n_src, size_t src_ld,
- size_t src_step, size_t dst_step, bool nullify_dst)
- : n_src_(n_src), src_ld_(src_ld), src_step_(src_step)
- , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {}
- virtual ~reducer_2d_driver_t() {}
- void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx)
- { assert(ker_); ker_(dst, srcs, ny, nx); }
-
-protected:
- int n_src_;
- size_t src_ld_, src_step_, dst_step_;
- bool nullify_dst_;
- void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx);
-};
-
-template <impl::data_type_t data_type, cpu_isa_t isa>
-struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t<data_type>,
- public jit_generator
-{
- DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t)
-
- /* cpu specific part */
- using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type;
- const AddressFrame &vmmword = (isa == avx2) ? yword : zword;
- void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op)
- { if (data_type == data_type::f32) vaddps(x1, x2, op);
- else vpaddd(x1, x2, op); }
- void uni_add(const Xmm& x1, const Operand& op)
- { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); }
-
- const int vlen = cpu_isa_traits<isa>::vlen;
- const int typesize
- = sizeof(typename mkldnn::impl::prec_traits<data_type>::type);
- Xbyak::Reg64 reg_dst = abi_param1;
- Xbyak::Reg64 reg_src = abi_param2;
- Xbyak::Reg64 reg_ny = abi_param3;
- Xbyak::Reg64 reg_nx = abi_param4;
-
- Xbyak::Reg64 reg_x = rax;
- Xbyak::Reg64 reg_src_id = r10;
-
- reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step,
- size_t dst_step, bool nullify_dst)
- : reducer_2d_driver_t<data_type>(n_src, src_ld, src_step,
- dst_step, nullify_dst)
- { generate(); }
-
- void nullify_dst(int nloads, int load_len) {
- UNUSED(load_len);
- for (int i = 0; i < nloads; ++i)
- uni_vpxor(Vmm(i), Vmm(i), Vmm(i));
- /* prefetches[dst] ? */
- }
-
- void load_dst(int nloads, int load_len) {
- for (int i = 0; i < nloads; ++i) {
- if (load_len == typesize)
- movd(Xmm(i), ptr[reg_dst + i * load_len]);
- else if (load_len == vlen)
- vmovups(Vmm(i), ptr[reg_dst + i * load_len]);
- else
- assert(!"unsupported");
- }
- }
-
- void store_dst(int nloads, int load_len) {
- for (int i = 0; i < nloads; ++i) {
- if (load_len == typesize)
- movd(ptr[reg_dst + i * load_len], Xmm(i));
- else if (load_len == vlen)
- vmovups(ptr[reg_dst + i * load_len], Vmm(i));
- else
- assert(!"unsupported");
- }
- }
-
- void accumulate(int nloads, int load_len, size_t base_off) {
- for (int i = 0; i < nloads; ++i) {
- size_t off = base_off + i * load_len;
-
- if (load_len == typesize)
- uni_add(Xmm(i), ptr[reg_src + off]);
- else if (load_len == vlen)
- uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
- else
- assert(!"unsupported");
- }
- }
-
- void loop_x() {
- const int nloads[] = {cpu_isa_traits<isa>::n_vregs, 1, 1};
- const int nbranches = sizeof(nloads) / sizeof(nloads[0]);
-
- const int load_len[nbranches] = {vlen, vlen, typesize};
- Label loop_x_label[nbranches + 1];
-
- mov(reg_x, reg_nx);
-
- for (int id = 0; id < nbranches; ++id) {
- L(loop_x_label[id]);
-
- cmp(reg_x, nloads[id] * load_len[id]);
- jl(loop_x_label[id + 1], T_NEAR);
-
- if (this->nullify_dst_)
- nullify_dst(nloads[id], load_len[id]);
- else
- load_dst(nloads[id], load_len[id]);
-
- if (nloads[id] > 1) {
- Label loop_srcs;
- mov(reg_src_id, this->n_src_);
- L(loop_srcs);
-
- accumulate(nloads[id], load_len[id], 0);
- add(reg_src, this->src_ld_ * typesize);
-
- dec(reg_src_id);
- jnz(loop_srcs, T_NEAR);
-
- sub(reg_src, this->n_src_ * this->src_ld_ * typesize);
- } else {
- for (int src_id = 0; src_id < this->n_src_; ++src_id) {
- const size_t base_off = src_id * this->src_ld_ * typesize;
- accumulate(nloads[id], load_len[id], base_off);
- }
- }
-
- store_dst(nloads[id], load_len[id]);
-
- add(reg_src, nloads[id] * load_len[id]);
- add(reg_dst, nloads[id] * load_len[id]);
-
- sub(reg_x, nloads[id] * load_len[id]);
-
- jmp(loop_x_label[id], T_NEAR);
- }
-
- L(loop_x_label[nbranches]);
-
- /* restore address registers */
- sub(reg_src, reg_nx);
- sub(reg_dst, reg_nx);
- }
-
- void generate() {
- assert(isa == avx2 || isa == avx512_common || isa == avx512_mic);
-
- preamble();
-
- shl(reg_nx, 2);
-
- Label ny_loop;
- L(ny_loop);
-
- loop_x();
-
- add(reg_dst, this->dst_step_ * typesize);
- add(reg_src, this->src_step_ * typesize);
-
- dec(reg_ny);
- jnz(ny_loop, T_NEAR);
-
- postamble();
- this->ker_ = reinterpret_cast<decltype(this->ker_)>(
- const_cast<uint8_t*>(this->getCode()));
- }
-};
-
-template <impl::data_type_t data_type>
-inline reducer_2d_driver_t<data_type> *create_reduce_2d_drv(int n_src,
- size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) {
- if (mayiuse(avx512_common))
- return new reducer_2d_driver_f_s_32_t<data_type, avx512_common>(n_src,
- src_ld, src_step, dst_step, nullify_dst);
- else if (mayiuse(avx2))
- return new reducer_2d_driver_f_s_32_t<data_type, avx2>(n_src, src_ld,
- src_step, dst_step, nullify_dst);
- assert(!"unimplemented");
- return nullptr;
-}
-
-/* cpu_reducer_t */
-
-template <impl::data_type_t data_type>
-void cpu_reducer_t<data_type>::conf_t::init_scratchpad(
- memory_tracking::registrar_t &scratchpad) const {
- if (balancer_.nthr_per_group_ == 1) return;
-
- const size_t space_size = balancer_.ngroups_
- * (balancer_.nthr_per_group_ - 1)
- * cpu_reducer_t<data_type>::space_per_thread(balancer_);
- scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K);
- scratchpad.book(key_reducer_space_bctx,
- sizeof(simple_barrier::ctx_t) * balancer_.ngroups_);
-}
-
-template <impl::data_type_t data_type>
-cpu_reducer_t<data_type>::cpu_reducer_t(const conf_t &conf)
- : conf_(conf), drv_(nullptr)
-{
- if (balancer().nthr_per_group_ == 1) return;
-
- drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_ - 1,
- space_per_thread(balancer()), 0, 0, false);
-}
-
-template <impl::data_type_t data_type>
-cpu_reducer_t<data_type>::~cpu_reducer_t() { delete drv_; }
-
-template <impl::data_type_t data_type>
-typename cpu_reducer_t<data_type>::data_t *
-cpu_reducer_t<data_type>::get_local_ptr(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- const int id_in_grp = balancer().id_in_group(ithr);
-
-    /* thread 0 of each group writes directly to the destination */
- if (id_in_grp == 0)
- return dst + balancer().ithr_job_off(ithr) * balancer().job_size_;
-
- const int grp_id = balancer().group_id(ithr);
- const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1)
- + (id_in_grp - 1);
-
- auto space = scratchpad.template get<data_t>(key_reducer_space);
- return space + offset_factor * space_per_thread(balancer());
-}
-
-template <impl::data_type_t data_type>
-void cpu_reducer_t<data_type>::reduce_nolock(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- bool redundant_reduction = balancer().nthr_per_group_ == 1
- || balancer().idle(ithr);
- if (redundant_reduction) return;
-
-#ifdef SIMPLE_IMPL
- if (balancer().id_in_group(ithr) != 0)
-        return; /* only thread 0 of each group does the reduction */
-
- const int njobs_in_grp = balancer().ithr_njobs(ithr);
- data_t *d = get_local_ptr(ithr, dst, scratchpad);
- for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp)
- {
- const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad);
- for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i)
- d[i] += space[i];
- }
-#else
- using namespace utils;
-
- const int id_in_grp = balancer().id_in_group(ithr);
- const int njobs_in_grp = balancer().ithr_njobs(ithr);
- const size_t cl = 64 / sizeof(data_t);
-
- const size_t reduction_size = njobs_in_grp * balancer().job_size_;
- size_t start{0}, end{0};
- balance211(div_up(reduction_size, cl), balancer().nthr_per_group_,
- id_in_grp, start, end);
-
- if (start == end) return;
-
- data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl;
- const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad)
- + start * cl;
- const size_t len = nstl::min(end * cl, reduction_size) - start * cl;
-
- (*drv_)(d, space, 1, len);
-#endif
-}
-
-template struct cpu_reducer_t<data_type::f32>;
-template struct cpu_reducer_t<data_type::s32>;
-
-/* cpu_reducer_2d_t */
-
-template <impl::data_type_t data_type>
-void cpu_reducer_2d_t<data_type>::conf_t::init_scratchpad(
- memory_tracking::registrar_t &scratchpad) const {
- if (balancer_.nthr_per_group_ == 1) return;
-
- const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_
- * cpu_reducer_2d_t<data_type>::space_per_thread(balancer_);
- scratchpad.book(key_reducer_space, sizeof(data_t) * space_size);
- scratchpad.book(key_reducer_space_bctx,
- sizeof(simple_barrier::ctx_t) * balancer_.ngroups_);
-}
-
-template <impl::data_type_t data_type>
-cpu_reducer_2d_t<data_type>::cpu_reducer_2d_t(const conf_t &conf)
- : conf_(conf), drv_(nullptr)
-{
- if (balancer().nthr_per_group_ == 1) return;
-
- drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_,
- space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_,
- true);
-}
-
-template <impl::data_type_t data_type>
-cpu_reducer_2d_t<data_type>::~cpu_reducer_2d_t() { delete drv_; }
-
-template <impl::data_type_t data_type>
-typename cpu_reducer_2d_t<data_type>::data_t *cpu_reducer_2d_t<data_type>::
-get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const {
- const int id_in_grp = balancer().id_in_group(ithr);
- const int grp_id = balancer().group_id(ithr);
- const int offset_factor = grp_id * balancer().nthr_per_group_ + id_in_grp;
- auto space = scratchpad.template get<data_t>(key_reducer_space);
- return space + offset_factor * space_per_thread(balancer());
-}
-
-template <impl::data_type_t data_type>
-int cpu_reducer_2d_t<data_type>::choose_x_blocking(int nx, int ny,
- int nthr_per_grp) const {
-    // find x_blocking to better balance the reduction work between threads
- assert(conf_.x_block_ > 0 && nx > conf_.x_block_
- && nx % conf_.x_block_ == 0);
- int x_blocking = nx / conf_.x_block_;
- int min_x_blocking =
- utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny));
- while (true) {
- if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2)
- x_blocking /= 2;
- else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3)
- x_blocking /= 3;
- else
- break;
- }
- if (x_blocking >= min_x_blocking * 4) x_blocking = 1;
- x_blocking *= conf_.x_block_;
- return x_blocking;
-}
-
-template <impl::data_type_t data_type>
-void cpu_reducer_2d_t<data_type>::reduce_block(const data_t* space_base,
- data_t *dst, int job, int start_y, int start_x,
- int ny_start, int nx_start, int ny_step, int nx_step) const {
- data_t *d = dst + (start_y + ny_start) * conf_.dst_x_
- + start_x + nx_start;
- const data_t *space = space_base + job * balancer().job_size_
- + ny_start * conf_.job_size_x_ + nx_start;
-#ifdef SIMPLE_IMPL
- for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) {
- const data_t *w = &space[idg * space_per_thread(balancer())];
- for (int y = 0; y < ny_step; ++y)
- for (int x = 0; x < nx_step; ++x) {
- d[y * conf_.dst_x_ + x]
- = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x])
- + w[y * conf_.job_size_x_ + x];
- }
- }
-#else
- (*drv_)(d, space, ny_step, nx_step);
-#endif
-}
-
-template <impl::data_type_t data_type>
-void cpu_reducer_2d_t<data_type>::reduce_nolock(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- bool redundant_reduction = balancer().nthr_per_group_ == 1
- || balancer().idle(ithr);
- if (redundant_reduction) return;
-
- const int id_in_grp = balancer().id_in_group(ithr);
- const int njobs_in_grp = balancer().ithr_njobs(ithr);
- const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_);
- const int global_job_start = balancer().ithr_job_off(ithr);
-
- const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad);
-
- const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_);
- const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps;
-
- if (id_in_grp >= pr_grps * pr_nthr_per_grp)
- return; /* idle */
-
- const int pr_my_grp = id_in_grp / pr_nthr_per_grp;
- const int pr_my_id = id_in_grp % pr_nthr_per_grp;
-
- int pr_job_start{0}, pr_job_end{0};
- balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end);
-
- for (int j = pr_job_start; j < pr_job_end; ++j) {
- const int global_job = global_job_start + j;
- const int j_y = global_job / njobs_x;
- const int j_x = global_job % njobs_x;
- const int start_y = j_y * conf_.job_size_y_;
- const int start_x = j_x * conf_.job_size_x_;
- const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_);
- const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_);
- int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp);
-
- int nxy_start{0}, nxy_end{0};
- balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id,
- nxy_start, nxy_end);
- if (nxy_start == nxy_end) continue;
- nxy_start *= x_blocking;
- nxy_end *= x_blocking;
-
- int nxy = nxy_start;
- if (nxy % nx != 0) {
- int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy);
- reduce_block(space_base, dst, j, start_y, start_x,
- nxy / nx, nxy % nx, 1, nx_step);
- nxy += nx_step;
- }
- if ((nxy_end - nxy) > nx) {
- int ny_step = (nxy_end - nxy) / nx;
- reduce_block(space_base, dst, j, start_y, start_x,
- nxy / nx, nxy % nx, ny_step, nx);
- nxy += nx * ny_step;
- }
- if ((nxy_end - nxy) > 0) {
- reduce_block(space_base, dst, j, start_y, start_x,
- nxy / nx, nxy % nx, 1, nxy_end - nxy);
- }
- }
-}
-
-template struct cpu_reducer_2d_t<data_type::f32>;
-template struct cpu_reducer_2d_t<data_type::s32>;
-
-/* accumulator section */
-
-template <impl::data_type_t data_type>
-cpu_accumulator_1d_t<data_type>::cpu_accumulator_1d_t(): drv_(nullptr) {
- drv_ = create_reduce_2d_drv<data_type>(1, 0, 0, 0, false);
-}
-
-template <impl::data_type_t data_type>
-cpu_accumulator_1d_t<data_type>::~cpu_accumulator_1d_t() {
- delete drv_;
-}
-
-template <impl::data_type_t data_type>
-void cpu_accumulator_1d_t<data_type>::accumulate(data_t *dst,
- const data_t *src, size_t size) {
- (*drv_)(dst, src, 1, size);
-}
-
-template struct cpu_accumulator_1d_t<data_type::f32>;
-template struct cpu_accumulator_1d_t<data_type::s32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
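For reference, the balancing heuristic deleted above (reduce_balancer_t::balance) picks the number of thread groups, the threads per group, and the jobs-per-group upper bound by brute-forcing the jobs-per-group candidate and minimizing an upper bound on per-thread work, subject to the reduction-buffer cap. The following is a minimal standalone C++ sketch of that heuristic, not the library code; the parameter values in main() are made up for illustration.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>

// Standalone restatement of the balancing heuristic above: pick the number of
// groups and threads per group that minimizes an upper bound on per-thread
// work, subject to the reduction-buffer cap. Names mirror reduce_balancer_t
// members, but this is an illustrative sketch only.
struct balance_result { int ngroups, nthr_per_group, njobs_per_group_ub; };

static int div_up(int a, int b) { return (a + b - 1) / b; }

balance_result balance(int nthr, int job_size, int njobs, int reduction_size,
        size_t max_buffer_size, bool syncable) {
    assert(nthr > 0 && job_size > 0 && njobs > 0 && reduction_size > 0);

    const int min_njobs_per_group = std::max(1, njobs / nthr);
    const int max_njobs_per_group = std::max(1,
            static_cast<int>(max_buffer_size / (size_t(nthr) * job_size)));

    // initial guess, then brute force over the jobs-per-group candidates
    int ngroups = std::min(njobs / min_njobs_per_group, nthr);
    int nthr_per_group = syncable ? std::min(nthr / ngroups, reduction_size) : 1;
    int njobs_per_group_ub = div_up(njobs, ngroups);
    size_t best = size_t(njobs) * job_size * reduction_size;

    for (int c = min_njobs_per_group; c < njobs; ++c) {
        const int c_ngroups = std::min(njobs / c, nthr);
        const int c_nthr = syncable
                ? std::min(nthr / c_ngroups, reduction_size) : 1;
        const int c_ub = div_up(njobs, c_ngroups);
        if (c_nthr > 1 && c_ub > max_njobs_per_group) continue; // buffer cap

        // per-thread cost: partial reduction plus (if needed) the final merge
        const size_t cost = size_t(job_size) * c_ub
                * (div_up(reduction_size, c_nthr) + (c_nthr != 1));
        if (cost < best) {
            best = cost;
            ngroups = c_ngroups;
            nthr_per_group = c_nthr;
            njobs_per_group_ub = c_ub;
        }
    }
    return {ngroups, nthr_per_group, njobs_per_group_ub};
}

int main() {
    // e.g. 16 threads reducing 8 jobs of 1024 elements over 32 partial results
    auto r = balance(16, 1024, 8, 32, 1u << 20, /*syncable=*/true);
    std::printf("%d groups x %d threads, <= %d jobs/group\n",
            r.ngroups, r.nthr_per_group, r.njobs_per_group_ub);
}

With these example inputs the search should settle on 8 groups of 2 threads with at most 1 job per group.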
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp
deleted file mode 100644
index 27f5939cd2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp
+++ /dev/null
@@ -1,334 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REDUCER_HPP
-#define CPU_REDUCER_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "mkldnn_types.h"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu_barrier.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-/** class to perform balancing over a 3D array
- *
- * Conceptually the reduction happens according to the picture below:
- *
- * <--job_size->
- * +-----------+ +-----------+ +-----------+ ^
- * | | | | | | |
- * | | | | | | |
- * | 1 | | 2 | . . . | njobs | | reduction_size
- * | | | | | | |
- * | | | | | | |
- * +-----------+ +-----------+ +-----------+ v
- *
- * | | | | | | | | |
- * v v v v v v v v v
- * ===================================================== vertical reduction
- *
- * +-----------+ +-----------+ . . . +-----------+ result
- *
- * In a simple case the result must be contiguous in memory.
- * @class cpu_reducer_t is an implementation.
- *
- * Threads are divided into groups. The groups are independent of each other.
- * Each group may work on several jobs (the distribution is not uniform, since
- * njobs might not be a multiple of the number of groups). Threads within a
- * group work on different parts of the reduction dimension. Thread 0 in each
- * group is called master (@sa reduce_balancer_t::master()).
- *
- * If the threading driver does not allow synchronization within a sub-group of
- * threads (e.g. Intel(R) TBB), the number of threads per group is forced to be 1.
- */
-struct reduce_balancer_t {
- reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */
- reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size,
- size_t max_buffer_size)
- { init(nthr, job_size, njobs, reduction_size, max_buffer_size); }
-
- reduce_balancer_t &init(int nthr, int job_size, int njobs,
- int reduction_size, size_t max_buffer_size)
- {
- syncable_ = mkldnn_thr_syncable();
- nthr_ = nthr;
- job_size_ = job_size;
- njobs_ = njobs;
- reduction_size_ = reduction_size;
- max_buffer_size_ = max_buffer_size;
- balance();
- return *this;
- }
-
- bool syncable_;
- int nthr_;
- int job_size_, njobs_, reduction_size_;
-
- int ngroups_; /** number of independent work (thread) groups */
- int nthr_per_group_; /** number of threads within a single work group */
- int njobs_per_group_ub_; /** the max # of jobs within a work group */
-
- bool master(int ithr) const { return id_in_group(ithr) == 0; }
- bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; }
-
- int group_id(int ithr) const { return ithr / nthr_per_group_; }
- int id_in_group(int ithr) const { return ithr % nthr_per_group_; }
-
- int grp_njobs(int grp) const {
- if (grp >= ngroups_) return 0;
- return njobs_ / ngroups_ + (grp < njobs_ % ngroups_);
- }
- int grp_job_off(int grp) const {
- if (grp >= ngroups_) return njobs_;
- return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_);
- }
-
- int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); }
- int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); }
-
-private:
- size_t max_buffer_size_;
- void balance();
-};
-
-/** forward declaration of reduce driver */
-template <impl::data_type_t data_type> struct reducer_2d_driver_t;
-
-/** class to perform a reduction over a 3D array
- *
- * Balancing is based on @class reduce_balancer_t.
- * Restrictions: the result of the reduction must be contiguous in memory.
- * The reduction happens according to the picture below (once more):
- *
- * <--job_size->
- * +-----------+ +-----------+ +-----------+ ^
- * | | | | | | |
- * | | | | | | |
- * | 1 | | 2 | . . . | njobs | | reduction_size
- * | | | | | | |
- * | | | | | | |
- * +-----------+ +-----------+ +-----------+ v
- *
- * | | | | | | | | |
- * v v v v v v v v v
- * ===================================================== vertical reduction
- *
- * +-----------+ +-----------+ . . . +-----------+ (contiguous) result
- *
- * An example of how the work might be shared is shown below.
- *
- * In this example group 0 owns 2 (independent) jobs -- the 2 big squares.
- * The number of threads per group is also 2 (thread 0 and thread 1 of
- * group 0). The master thread (i.e. the thread with id 0 in its group) of
- * each group puts its partial result directly into destination memory,
- * while all the other threads in the group (in the picture, only thread 1)
- * use the workspace. Once the intermediate results are obtained, each group
- * reduces its own part (its own jobs) into the destination memory.
- *
- * <------- group 0 ------->
- *
- * +-----------+ +-----------+ ^
- * | | | | | thread 0 of reduces to the dest-memory
- * | | | | | group 0 +-----------+ +-----------+
- * |- - - - - -| |- - - - - -| X
- * | | | | | thread 1 of reduces to workspace[tid=1]:
- * | | | | | group 0 +-----------+ +-----------+
- * +-----------+ +-----------+ v
- * | | | | | |
- * v v v v v v
- * ((barrier)) =============================
- *
- * dest-memory: +-----------+ +-----------+
- */
-template <impl::data_type_t data_type>
-struct cpu_reducer_t {
- typedef typename prec_traits<data_type>::type data_t;
-
- struct conf_t {
- conf_t() = default;
- conf_t &init(const reduce_balancer_t &balancer)
- { balancer_ = balancer; return *this; }
-
- void init_scratchpad(memory_tracking::registrar_t &scratchpad) const;
-
- reduce_balancer_t balancer_;
- };
-
- cpu_reducer_t(const conf_t &conf);
- ~cpu_reducer_t();
-
- /** initializes reducer.
- * Must be called from a single thread prior to actual usage */
- void init(const memory_tracking::grantor_t &scratchpad) const {
- if (balancer().nthr_per_group_ == 1) return;
-
- auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
- memory_tracking::names::key_reducer_space_bctx);
- for (int i = 0; i < balancer().ngroups_; ++i)
- simple_barrier::ctx_init(&bctx[i]);
- }
-
-    /** for a given thread, returns the pointer where partial results should
-     * be written. The reduction destination @p dst must be provided as well
-     * (the master thread of each group writes its partial result there to
-     * reduce memory pressure).
-     *
-     * @note: the job offset is already applied by get_local_ptr(), which means
-     *        all threads should start writing from the very beginning of the
-     *        returned address.
-     */
- data_t *get_local_ptr(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
-
- /** performs the reduction with built-in synchronization. */
- void reduce(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- bool redundant_reduction = balancer().nthr_per_group_ == 1
- || balancer().idle(ithr);
- if (redundant_reduction) return;
-
- auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
- memory_tracking::names::key_reducer_space_bctx);
- simple_barrier::barrier(&bctx[balancer().group_id(ithr)],
- balancer().nthr_per_group_);
-
- reduce_nolock(ithr, dst, scratchpad);
- }
-
- const reduce_balancer_t &balancer() const { return conf_.balancer_; }
-
-private:
- static size_t space_per_thread(const reduce_balancer_t &balancer)
- { return balancer.njobs_per_group_ub_ * balancer.job_size_; }
-
- /* The scratchpad is organized as follows:
- *
-     * data_t space[nthr_][njobs_per_group_ub_][job_size_];
-     * simple_barrier::ctx_t barriers[ngroups_]; */
-
- const conf_t conf_;
- reducer_2d_driver_t<data_type> *drv_;
-
- void reduce_nolock(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
-};
-
-template <impl::data_type_t data_type>
-struct cpu_reducer_2d_t {
- typedef typename prec_traits<data_type>::type data_t;
-
- struct conf_t {
- conf_t() = default;
- conf_t &init(const reduce_balancer_t &balancer, int job_size_x,
- int job_size_y, int x_block, int dst_x, int dst_y) {
- balancer_ = balancer;
- job_size_x_ = job_size_x;
- job_size_y_ = job_size_y;
- x_block_ = x_block;
- dst_x_ = dst_x;
- dst_y_ = dst_y;
- return *this;
- }
-
- void init_scratchpad(memory_tracking::registrar_t &scratchpad) const;
-
- reduce_balancer_t balancer_;
- int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_;
- };
-
- cpu_reducer_2d_t(const conf_t &conf);
- ~cpu_reducer_2d_t();
-
- /** initializes reducer.
- * Must be called from a single thread prior to actual usage */
- void init(const memory_tracking::grantor_t &scratchpad) const {
- if (balancer().nthr_per_group_ == 1) return;
-
- auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
- memory_tracking::names::key_reducer_space_bctx);
- for (int i = 0; i < balancer().ngroups_; ++i)
- simple_barrier::ctx_init(&bctx[i]);
- }
-
-    /** for a given thread, returns the pointer where partial results should be written */
- data_t *get_local_ptr(int ithr,
- const memory_tracking::grantor_t &scratchpad) const;
-
- /** performs the reduction with built-in synchronization. */
- void reduce(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- bool redundant_reduction = balancer().nthr_per_group_ == 1
- || balancer().idle(ithr);
- if (redundant_reduction) return;
-
- auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
- memory_tracking::names::key_reducer_space_bctx);
- simple_barrier::barrier(&bctx[balancer().group_id(ithr)],
- balancer().nthr_per_group_);
-
- reduce_nolock(ithr, dst, scratchpad);
- }
-
- const reduce_balancer_t &balancer() const { return conf_.balancer_; }
-
-private:
- static size_t space_per_thread(const reduce_balancer_t &balancer)
- { return balancer.njobs_per_group_ub_ * balancer.job_size_; }
-
- /* The scratchpad is organized as follows:
- *
-     * data_t space[nthr_][njobs_per_group_ub_][job_size_];
-     * simple_barrier::ctx_t barriers[ngroups_]; */
-
- const conf_t conf_;
- reducer_2d_driver_t<data_type> *drv_;
-
- int choose_x_blocking(int nx, int ny, int nthr_per_grp) const;
- void reduce_block(const data_t* space_base, data_t *dst,
- int job, int start_y, int start_x,
- int ny_start, int nx_start, int ny_step, int nx_step) const;
- void reduce_nolock(int ithr, data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
-};
-
-/** simple 1d accumulator: y[:] += x[:] */
-template <impl::data_type_t data_type>
-struct cpu_accumulator_1d_t {
- typedef typename prec_traits<data_type>::type data_t;
-
- cpu_accumulator_1d_t();
- ~cpu_accumulator_1d_t();
- void accumulate(data_t *dst, const data_t *src, size_t size);
-
- reducer_2d_driver_t<data_type> *drv_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
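The scheme described in the header comment above (the master thread writes straight to the destination, the other group members write to per-thread workspace, a barrier, then a final fold) can be illustrated with plain standard threads. This is a minimal C++20 sketch of the idea for a single group, not the library implementation: the real reducer splits the final fold across the whole group and performs it with the jitted 2D driver.

#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

// Minimal sketch of a single group: the master thread (id 0) accumulates
// directly into dst, the other threads use one workspace row each, and after
// a barrier the master folds the workspace rows into dst.
int main() {
    constexpr int nthr = 4, job_size = 8, reduction_chunks = 16;
    std::vector<float> dst(job_size, 0.f);
    std::vector<float> workspace((nthr - 1) * job_size, 0.f);
    std::barrier<> sync(nthr);

    auto worker = [&](int ithr) {
        float *out = ithr == 0 ? dst.data()
                               : workspace.data() + (ithr - 1) * job_size;
        // each thread reduces its share of the reduction dimension
        for (int r = ithr; r < reduction_chunks; r += nthr)
            for (int i = 0; i < job_size; ++i)
                out[i] += 1.f; // stand-in for the real partial sums
        sync.arrive_and_wait();
        if (ithr == 0) // final fold: add every workspace row into dst
            for (int w = 0; w < nthr - 1; ++w)
                for (int i = 0; i < job_size; ++i)
                    dst[i] += workspace[w * job_size + i];
    };

    std::vector<std::thread> pool;
    for (int t = 0; t < nthr; ++t) pool.emplace_back(worker, t);
    for (auto &t : pool) t.join();
    std::printf("dst[0] = %g (expect %d)\n", dst[0], reduction_chunks);
}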
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp
deleted file mode 100644
index 82be70353d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp
+++ /dev/null
@@ -1,262 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "cpu_engine.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_reorder_pd.hpp"
-#include "cpu_memory.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu/jit_uni_reorder.hpp"
-#include "cpu/simple_reorder.hpp"
-#include "cpu/wino_reorder.hpp"
-#include "cpu/rnn/rnn_reorders.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
-
-namespace {
-using namespace mkldnn::impl::data_type;
-using namespace mkldnn::impl::format_tag;
-
-#define REG_SR(idt, ifmt, odt, ofmt, ...) \
- simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create
-
-#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \
- REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \
- REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse)
-
-#define REG_SR_DIRECT_COPY(idt, odt) \
- REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \
- REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0)
-
-static const rpd_create_f cpu_reorder_impl_list[] = {
- /* winograd */
- wino_reorder_t<f32, f32>::pd_t::create,
- //wino_reorder_t<f32, s8>::pd_t::create,
-
- /* rnn reorders */
- rnn_data_reorder_t<f32, u8>::pd_t::create,
- rnn_weights_reorder_t<f32, f32>::pd_t::create,
- rnn_weights_reorder_t<f32, s8>::pd_t::create,
-
- /* conv reorders w/ compensation */
- REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
- REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
-
- REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
-
- REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
-
- REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
-
- REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
-
- REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
- REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
- REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
-
- /* regular reorders */
-
-#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
-    /* Direct copy for icc, which is faster than jitted code;
-     * direct copy for gcc, which might or might not be faster than jitted
-     * code, but is still worth it because it doesn't require jitting, i.e.
-     * much faster creation time. This is a tentative solution and should be
-     * removed later (when we cache jitted code?...). */
- REG_SR_DIRECT_COPY(f32, f32),
-#endif
-
-#ifdef __INTEL_COMPILER
- /* direct copy for icc, which is faster than jitted code */
- /*
- REG_SR_DIRECT_COPY(f32, s32),
- REG_SR_DIRECT_COPY(f32, s8),
- REG_SR_DIRECT_COPY(f32, u8),
- REG_SR_DIRECT_COPY(s32, f32),
- REG_SR_DIRECT_COPY(s32, s32),
- REG_SR_DIRECT_COPY(s32, s8),
- REG_SR_DIRECT_COPY(s32, u8),
- REG_SR_DIRECT_COPY(s8, f32),
- REG_SR_DIRECT_COPY(s8, s32),
- REG_SR_DIRECT_COPY(s8, s8),
- REG_SR_DIRECT_COPY(s8, u8),
- REG_SR_DIRECT_COPY(u8, f32),
- REG_SR_DIRECT_COPY(u8, s32),
- REG_SR_DIRECT_COPY(u8, s8),
- REG_SR_DIRECT_COPY(u8, u8),
- */
-#endif
-
- /* jit */
- jit_uni_reorder_create,
-
- /* fp32: flat <-> blocked with tail */
- /*
- REG_SR_BIDIR(f32, any, f32, nCw4c),
- REG_SR_BIDIR(f32, any, f32, nCw8c),
- REG_SR_BIDIR(f32, any, f32, OIw4i4o),
- REG_SR_BIDIR(f32, any, f32, OIw8i8o),
- REG_SR_BIDIR(f32, any, f32, OIw8o8i),
- REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
- REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
- REG_SR_BIDIR(f32, any, f32, gOIw8o8i),
-
- REG_SR_BIDIR(f32, any, f32, nCw16c),
- REG_SR_BIDIR(f32, any, f32, OIw16o16i),
- REG_SR_BIDIR(f32, any, f32, OIw16i16o),
- REG_SR_BIDIR(f32, any, f32, IOw16o16i),
- REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
- REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
- REG_SR_BIDIR(f32, any, f32, gIOw16o16i),
-
- REG_SR_BIDIR(f32, any, f32, nChw4c),
- REG_SR_BIDIR(f32, any, f32, nChw8c),
- REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
- REG_SR_BIDIR(f32, any, f32, Ohwi8o),
-
- REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
- REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
- REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
- REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
- REG_SR_BIDIR(f32, any, f32, gOhwi8o),
- REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
- REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),
-
- REG_SR_BIDIR(f32, any, f32, nChw16c),
- REG_SR_BIDIR(f32, any, f32, Oihw4o),
- REG_SR_BIDIR(f32, any, f32, Oihw16o),
- REG_SR_BIDIR(f32, any, f32, Ohwi4o),
- REG_SR_BIDIR(f32, any, f32, Ohwi16o),
- REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
- REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
- REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
- REG_SR_BIDIR(f32, any, f32, gOihw4o),
- REG_SR_BIDIR(f32, any, f32, gOihw16o),
- REG_SR_BIDIR(f32, any, f32, gOhwi4o),
- REG_SR_BIDIR(f32, any, f32, gOhwi16o),
- REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
- REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
- REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),
-
- REG_SR_BIDIR(f32, any, f32, nCdhw4c),
- REG_SR_BIDIR(f32, any, f32, nCdhw8c),
- REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
- REG_SR_BIDIR(f32, any, f32, Odhwi8o),
- REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
- REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
- REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
- REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
- REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
- REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),
-
- REG_SR_BIDIR(f32, any, f32, nCdhw16c),
- REG_SR_BIDIR(f32, any, f32, Oidhw4o),
- REG_SR_BIDIR(f32, any, f32, Oidhw16o),
- REG_SR_BIDIR(f32, any, f32, Odhwi16o),
- REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
- REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
- REG_SR_BIDIR(f32, any, f32, gOidhw4o),
- REG_SR_BIDIR(f32, any, f32, gOidhw16o),
- REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
- REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
- REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
- */
-
- /* fp32: blocked <-> blocked with tail */
- REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
- REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
- REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
-
- /* int: flat <-> blocked with tail */
- /*
- REG_SR_BIDIR(f32, any, s32, nChw16c),
- REG_SR_BIDIR(f32, any, s8, nChw16c),
- REG_SR_BIDIR(f32, any, u8, nChw16c),
- REG_SR_BIDIR(s32, any, f32, nChw16c),
- REG_SR_BIDIR(s32, any, s32, nChw16c),
- REG_SR_BIDIR(s32, any, s8, nChw16c),
- REG_SR_BIDIR(s32, any, u8, nChw16c),
- REG_SR_BIDIR(s8, any, f32, nChw16c),
- REG_SR_BIDIR(s8, any, s32, nChw16c),
- REG_SR_BIDIR(s8, any, s8, nChw16c),
- REG_SR_BIDIR(s8, any, u8, nChw16c),
- REG_SR_BIDIR(u8, any, f32, nChw16c),
- REG_SR_BIDIR(u8, any, s32, nChw16c),
- REG_SR_BIDIR(u8, any, s8, nChw16c),
- REG_SR_BIDIR(u8, any, u8, nChw16c),
-
- REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
- REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
- REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
- REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
- REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
- REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
- REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
- REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),
- */
-
- /* reference: the last line of defence */
- /*
- REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
- REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
- REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
- REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),
-
- REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
- REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
- REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
- REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),
-
- REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
- REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
- REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
- REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),
-
- REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
- REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
- REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
- REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
- */
-
- /* eol */
- nullptr,
-};
-}
-
-const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
- return cpu_reorder_impl_list;
-}
-
-}
-}
-}
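The implementation list deleted above is a nullptr-terminated array of create callbacks that the engine scans in priority order, taking the first one that succeeds. A minimal sketch of that dispatch pattern is shown below; the names (desc_t, create_jit, create_reference) are invented for illustration.

#include <cstdio>

// nullptr-terminated array of "try to create" callbacks, scanned in priority
// order; the first one that succeeds wins.
enum class status { success, unimplemented };
struct desc_t { bool blocked_layout; };

using create_f = status (*)(const desc_t &);

static status create_jit(const desc_t &d)
{ return d.blocked_layout ? status::success : status::unimplemented; }
static status create_reference(const desc_t &)
{ return status::success; }

static const create_f impl_list[] = {
    create_jit,        /* preferred implementation */
    create_reference,  /* last line of defence */
    nullptr,           /* eol */
};

int main() {
    desc_t d{false};
    for (const create_f *f = impl_list; *f != nullptr; ++f)
        if ((*f)(d) == status::success) {
            std::printf("picked implementation #%d\n", int(f - impl_list));
            break;
        }
}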
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp
deleted file mode 100644
index 1622eb6849..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp
+++ /dev/null
@@ -1,48 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REORDER_PD_HPP
-#define CPU_REORDER_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "reorder_pd.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_reorder_pd_t: public reorder_pd_t {
- using reorder_pd_t::reorder_pd_t;
-
- status_t init() {
- const auto &post_ops = attr()->post_ops_;
- bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1
- && post_ops.entry_[0].kind == primitive_kind::sum);
- scratchpad_engine_ = src_engine_;
- return args_ok ? status::success : status::unimplemented;
- }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp
deleted file mode 100644
index f16587b99f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp
+++ /dev/null
@@ -1,41 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_SHUFFLE_PD_HPP
-#define CPU_SHUFFLE_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "shuffle_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_shuffle_pd_t: public shuffle_pd_t {
- using shuffle_pd_t::shuffle_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp
deleted file mode 100644
index 3a39eab974..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp
+++ /dev/null
@@ -1,45 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_SOFTMAX_PD_HPP
-#define CPU_SOFTMAX_PD_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "softmax_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t {
- using softmax_fwd_pd_t::softmax_fwd_pd_t;
-};
-
-struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t {
- using softmax_bwd_pd_t::softmax_bwd_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp
deleted file mode 100644
index 1ab5d9f174..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp
+++ /dev/null
@@ -1,48 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "cpu_engine.hpp"
-
-/*
-#include "cpu/ref_sum.hpp"
-#include "cpu/simple_sum.hpp"
-*/
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f;
-
-namespace {
-#define INSTANCE(...) __VA_ARGS__::pd_t::create
-static const spd_create_f cpu_sum_impl_list[] = {
- /*
- INSTANCE(simple_sum_t<data_type::f32>),
- INSTANCE(ref_sum_t),
- */
- nullptr,
-};
-#undef INSTANCE
-}
-
-const spd_create_f *cpu_engine_t::get_sum_implementation_list() const {
- return cpu_sum_impl_list;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp
deleted file mode 100644
index 0965129f9b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp
+++ /dev/null
@@ -1,39 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_SUM_PD_HPP
-#define CPU_SUM_PD_HPP
-
-#include "c_types_map.hpp"
-#include "sum_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_sum_pd_t: public sum_pd_t {
- using sum_pd_t::sum_pd_t;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
deleted file mode 100644
index a9810dec28..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
+++ /dev/null
@@ -1,372 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-#include <cmath>
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-#include "gemm_utils_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-namespace gemm_utils {
-#define BM_NOCOPY_AVX 64
-#define BN_NOCOPY_AVX 48
-#define BK_NOCOPY_AVX 384
-#define BN_LARGE_NOCOPY_AVX 192
-#define BM_SMALL_NOCOPY_AVX 16
-#define BN_SMALL_NOCOPY_AVX 1
-#define BK_SMALL_NOCOPY_AVX 4
-// Determine number of threads for each dimension of a 3-D partitioning
-// algorithm based on input parameters
-// m/n/k - First/second/third parameter for GEMM
-// nthrs - total available number of threads
-// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
-// BM/BN/BK - blocking values
-void calc_nthr_nocopy_avx(int m, int n, int k,
- int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
- int *BK)
-{
- int nthr, nthr_m, nthr_n, nthr_k;
- int MB, NB, KB;
-
- nthr = nthrs;
- nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
- nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
- nthr_k = 1;
-
- // Partition along K dimension
- // - if threading allows having barriers (e.g. OMP)
- // - if there is not enough parallelism along M or N
- if (mkldnn_thr_syncable()) {
- int nthr_other = nthr_k = 1;
- while ((nthr_m * nthr_n * nthr_other < nthr)
- && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
- nthr_other++;
- if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
- nthr_k = nthr_other;
- }
- }
- nthr /= nthr_k;
-
- if (nthr_m == 1)
- nthr_n = nthr;
- if (nthr_n == 1)
- nthr_m = nthr;
-
- // Simple partition reduction
- while (nthr_m * nthr_n > nthr)
- if (nthr_m > nthr_n)
- nthr_m--;
- else
- nthr_n--;
- while (nthr_m * nthr_n < nthr)
- if (nthr_m < nthr_n)
- nthr_m++;
- else
- nthr_n++;
-
- if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
-
- if (nthr_m <= nthr_n) {
- nthr_m = (int)sqrt((double)nthr);
- if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
- nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
- nthr_n = nthr / nthr_m;
-
- while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
- nthr_m--;
- nthr_n = nthr / nthr_m;
- }
- } else {
- nthr_n = (int)sqrt((double)nthr);
- if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
- nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
- nthr_m = nthr / nthr_n;
-
- while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
- nthr_n--;
- nthr_m = nthr / nthr_n;
- }
- }
- }
-
- MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
- MB -= MB % BM_SMALL_NOCOPY_AVX;
- NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
- NB -= NB % BN_SMALL_NOCOPY_AVX;
- KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
- KB -= KB % BK_SMALL_NOCOPY_AVX;
-
- if (MB * nthr_m > m)
- nthr_m = (m + MB - 1) / MB;
- if (NB * nthr_n > n)
- nthr_n = (n + NB - 1) / NB;
- if (KB * nthr_k > k)
- nthr_k = (k + KB - 1) / KB;
-
- *nthrs_m = nthr_m;
- *nthrs_n = nthr_n;
- *nthrs_k = nthr_k;
-
- *BM = MB;
- *BN = NB;
- *BK = KB;
-}
-#undef BM_NOCOPY_AVX
-#undef BN_NOCOPY_AVX
-#undef BK_NOCOPY_AVX
-#undef BN_LARGE_NOCOPY_AVX
-#undef BM_SMALL_NOCOPY_AVX
-#undef BN_SMALL_NOCOPY_AVX
-#undef BK_SMALL_NOCOPY_AVX
-
-#define BM_NOCOPY_AVX512_COMMON 32
-#define BN_NOCOPY_AVX512_COMMON 64
-#define BK_NOCOPY_AVX512_COMMON 192
-#define BN_LARGE_NOCOPY_AVX512_COMMON 192
-#define BM_SMALL_NOCOPY_AVX512_COMMON 16
-#define BN_SMALL_NOCOPY_AVX512_COMMON 1
-#define BK_SMALL_NOCOPY_AVX512_COMMON 4
-// Determine number of threads for each dimension of a 3-D partitioning
-// algorithm based on input parameters
-// m/n/k - First/second/third parameter for GEMM
-// nthrs - total available number of threads
-// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
-// BM/BN/BK - blocking values
-void calc_nthr_nocopy_avx512_common(int m,
- int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
- int *BM, int *BN, int *BK)
-{
- int nthr, nthr_m, nthr_n, nthr_k = 1;
- int MB, NB, KB;
- nthr = nthrs;
-
- int counter = 0;
- float ratio_float = 1.;
- int ratio = 1;
- nthr = nthrs;
- int nthr_m_gt_n;
-
- // Partition along K dimension
- // - if threading allows having barriers (e.g. OMP)
- // - if there is not enough parallelism along M or N
- if (mkldnn_thr_syncable()) {
- if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
- m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
- nthr_k = k / BK_NOCOPY_AVX512_COMMON;
- if (nthr_k > nthr / 4)
- nthr_k = nthr / 4;
- if (nthr_k < 1)
- nthr_k = 1;
-
- while ((nthr_k > 1) && (nthr % nthr_k)) {
- nthr_k--;
- }
- nthr /= nthr_k;
- } else {
- nthr_k = 1;
- }
- }
- nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
- nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
-
- if (nthr_m < 1)
- nthr_m = 1;
- if (nthr_n < 1)
- nthr_n = 1;
-
- nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
- ratio_float = (float)nthr_m / nthr_n;
-
- if (nthr_m_gt_n)
- ratio = (int)ratio_float;
- else
- ratio = (int)(1. / ratio_float);
-
- // scale down nthr_m and nthr_n if they are too large
- while (nthr_m * nthr_n > 4 * nthr) {
- nthr_m /= 2;
- nthr_n /= 2;
- }
-
- if (nthr_m < 1)
- nthr_m = 1;
- if (nthr_n < 1)
- nthr_n = 1;
-
- // Simple partition reduction
- counter = 0;
- while (nthr_m * nthr_n > nthr) {
- if (nthr_m > nthr_n) {
- if (counter < ratio)
- nthr_m--;
- else {
- nthr_n--;
- counter = -1;
- }
- } else {
- if (counter < ratio)
- nthr_n--;
- else {
- nthr_m--;
- counter = -1;
- }
- }
- counter++;
- }
-
- // Simple partition increment
- counter = 0;
- while (nthr_m * nthr_n < 0.95 * nthr) {
- if (nthr_m > nthr_n) {
- if (counter < ratio)
- nthr_m++;
- else {
- nthr_n++;
- counter = -1;
- }
- } else {
- if (counter < ratio)
- nthr_n++;
- else {
- nthr_m++;
- counter = -1;
- }
- }
- counter++;
- }
-
- // if nothing works out, then this should work
- if ((nthr_m * nthr_n > nthr)) {
-
- if (nthr_m <= nthr_n) {
- nthr_m = (int)sqrt((double)nthr);
- if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
- / BM_SMALL_NOCOPY_AVX512_COMMON)
- nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
- / BM_SMALL_NOCOPY_AVX512_COMMON;
- nthr_n = nthr / nthr_m;
-
- while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
- nthr_m--;
- nthr_n = nthr / nthr_m;
- }
- } else {
- nthr_n = (int)sqrt((double)nthr);
- if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
- / BN_SMALL_NOCOPY_AVX512_COMMON)
- nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
- / BN_SMALL_NOCOPY_AVX512_COMMON;
- nthr_m = nthr / nthr_n;
-
- while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
- nthr_n--;
- nthr_m = nthr / nthr_n;
- }
- }
- }
-
- MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
- MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
- NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
- NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
- KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
- KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
-
- if (MB * nthr_m > m)
- nthr_m = (m + MB - 1) / MB;
- if (NB * nthr_n > n)
- nthr_n = (n + NB - 1) / NB;
- if (KB * nthr_k > k)
- nthr_k = (k + KB - 1) / KB;
-
- *nthrs_m = nthr_m;
- *nthrs_n = nthr_n;
- *nthrs_k = nthr_k;
-
- *BM = MB;
- *BN = NB;
- *BK = KB;
-}
-#undef BM_NOCOPY_AVX512_COMMON
-#undef BN_NOCOPY_AVX512_COMMON
-#undef BK_NOCOPY_AVX512_COMMON
-#undef BN_LARGE_NOCOPY_AVX512_COMMON
-#undef BM_SMALL_NOCOPY_AVX512_COMMON
-#undef BN_SMALL_NOCOPY_AVX512_COMMON
-#undef BK_SMALL_NOCOPY_AVX512_COMMON
-
-// Partition n values as equally as possible among nthr threads
-// and set the offset (t_offset) and number of values (t_block) for ithr
-// Assumption: 0 <= ithr < nthr
-void partition_unit_diff(
- int ithr, int nthr, int n, int *t_offset, int *t_block)
-{
- int band = n / nthr;
- if (band == 0)
- band = 1;
- int tail = n - band * nthr;
- if (tail < 0)
- tail = 0;
-
- if (ithr < tail) {
- band++;
- *t_offset = band * ithr;
- *t_block = band;
- } else {
- *t_offset = band * ithr + tail;
- *t_block = band;
- }
-
- if (*t_offset >= n) {
- *t_offset = 0;
- *t_block = 0;
- }
-
- if (*t_offset + *t_block > n) {
- *t_block = n - *t_offset;
- }
-}
-
-// Sum the m*n values from p_src into p_dst, assuming the two-dimensional
-// arrays have leading dimensions ld_src and ld_dst, respectively
-template<typename data_t>
-void sum_two_matrices(int m, int n,
- data_t * __restrict p_src, dim_t ld_src,
- data_t * __restrict p_dst, dim_t ld_dst)
-{
- int i, j;
- for (j = 0; j < n; j++) {
- for (i = 0; i < m; i++) {
- p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
- }
- }
-}
-
-template
-void sum_two_matrices<float>(int m, int n,
- float * __restrict p_src, dim_t ld_src,
- float * __restrict p_dst, dim_t ld_dst);
-
-template
-void sum_two_matrices<double>(int m, int n,
- double * __restrict p_src, dim_t ld_src,
- double * __restrict p_dst, dim_t ld_dst);
-}
-}
-}
-}
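partition_unit_diff() above splits n values as evenly as possible among nthr threads, giving the first tail = n - (n/nthr)*nthr threads one extra element. A small self-contained driver (the body below mirrors the deleted function) shows the resulting ranges:

#include <cstdio>

// Split n values as evenly as possible among nthr threads and print each
// thread's half-open range; mirrors partition_unit_diff() above so the
// example is self-contained.
static void partition_unit_diff(int ithr, int nthr, int n,
        int *t_offset, int *t_block) {
    int band = n / nthr;
    if (band == 0) band = 1;
    int tail = n - band * nthr;
    if (tail < 0) tail = 0;

    if (ithr < tail) {
        band++;
        *t_offset = band * ithr;
        *t_block = band;
    } else {
        *t_offset = band * ithr + tail;
        *t_block = band;
    }
    if (*t_offset >= n) { *t_offset = 0; *t_block = 0; }
    if (*t_offset + *t_block > n) *t_block = n - *t_offset;
}

int main() {
    const int n = 10, nthr = 4;
    for (int ithr = 0; ithr < nthr; ++ithr) {
        int off, block;
        partition_unit_diff(ithr, nthr, n, &off, &block);
        std::printf("thread %d: [%d, %d)\n", ithr, off, off + block);
    }
    // expected: thread 0: [0, 3)  thread 1: [3, 6)  thread 2: [6, 8)  thread 3: [8, 10)
}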
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
deleted file mode 100644
index 3352298b4a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
+++ /dev/null
@@ -1,72 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef GEMM_UTILS_HPP
-#define GEMM_UTILS_HPP
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace gemm_utils {
-// Alias for any dimension related variable.
-typedef ptrdiff_t dim_t;
-
-template <typename T, bool isTransA, bool isTransB>
-struct gemm_traits {};
-
-template <bool isTransA, bool isTransB>
-struct gemm_traits<double, isTransA, isTransB> {
- static constexpr int m = 8;
- static constexpr int n = 6;
- static constexpr int BM = 4032;
- static constexpr int BN = isTransA ? 96 : 192;
- static constexpr int BK = isTransB ? 96 : 512;
-};
-
-template <bool isTransA, bool isTransB>
-struct gemm_traits<float, isTransA, isTransB> {
- static constexpr int m = 16;
- static constexpr int n = 6;
- static constexpr int BM = 4032;
- static constexpr int BN = isTransA ? 96 : 48;
- static constexpr int BK = isTransB ? 96 : 256;
-};
-
-template <typename T>
-using unroll_factor = gemm_traits<T, false, false>;
-
-template <typename data_t>
-void sum_two_matrices(int m, int n,
- data_t * __restrict p_src, dim_t ld_src,
- data_t * __restrict p_dst, dim_t ld_dst);
-
-void calc_nthr_nocopy_avx512_common(int m,
- int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
- int *BM, int *BN, int *BK);
-
-void calc_nthr_nocopy_avx(int m, int n, int k,
- int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
- int *BK);
-
-void partition_unit_diff(
- int ithr, int nthr, int n, int *t_offset, int *t_block);
-};
-
-}
-}
-}
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
deleted file mode 100644
index d7be43e392..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
+++ /dev/null
@@ -1,2131 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <cmath>
-#include <mutex>
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "ref_gemm_f32.hpp"
-#include "gemm_utils_f32.hpp"
-#include "jit_avx512_common_gemm_f32.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-#define CACHE_LINE_SIZE 64
-
-#define STACKSIZE get_size_of_abi_save_regs()
-#ifdef _WIN32
-#define STACK_K_CAPACITY 32
-#else
-#define STACK_K_CAPACITY 2048
-#endif
-#define SIZE 4
-#define OFFSET 128
-#define BASE_SHIFT 2
-#define SECOND_FETCH unroll_n
-#define UNROLL_M 48
-#define UNROLL_N 8
-
-namespace avx512_common_gemm_f32 {
-using namespace gemm_utils;
-
-struct xbyak_gemm : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
-
- xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
- void *code_ptr = nullptr,
- size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
- : jit_generator(code_ptr, code_size)
- {
- using namespace Xbyak;
-
- enum { ver_avx512_core, ver_avx512_mic } ver =
- mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
-
- bool isBeta0 = (beta == 0.0);
- bool isBetaN = (!isBeta0 && beta != 1.0);
-
- // various definitions for convenience
- auto ARG_M = abi_param1;
- auto ARG_N = abi_param2;
- auto K = abi_param3;
- auto ARG_ALPHA = abi_param4;
-#ifdef _WIN32
- auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
- auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
- sizeof(float *) + STACKSIZE];
- const auto stackOffset = OFFSET_SHADOWSPACE +
- sizeof(float *) + STACKSIZE;
- auto A = rsi;
- auto LDA = rdi;
-#else
- auto ARG_A = r8;
- auto ARG_LDA = r9;
- const auto stackOffset = STACKSIZE;
- auto A = ARG_A;
- auto LDA = ARG_LDA;
-#endif
- auto ARG_B = ptr[rsp + 8 + stackOffset];
- auto ARG_LDB = ptr[rsp + 16 + stackOffset];
- auto ARG_BETA = ptr[rsp + 24 + stackOffset];
- auto ARG_C = ptr[rsp + 32 + stackOffset];
- auto ARG_LDC = ptr[rsp + 40 + stackOffset];
- auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
- auto ARG_WS = ptr[rsp + 56 + stackOffset];
-
- auto B = r11;
- auto LDB = rbx;
- auto LDC = r13;
- auto LL = rax;
- auto AO1 = abi_param2;
- auto BO1 = abi_param4;
- auto BO2 = rbp;
- auto CO1 = r14;
- auto CO2 = r15;
- auto LDB3 = r10;
- auto LDA4 = abi_param1;
- auto AA = r12;
- auto BIAS1 = abi_param1;
-
- auto M = qword[rsp + 0];
- auto N = qword[rsp + 8];
- auto FLAG = qword[rsp + 16];
- auto I = qword[rsp + 24];
- auto C = qword[rsp + 32];
- auto BIAS = qword[rsp + 40];
- auto ALPHA = qword[rsp + 48];
- auto BETA = qword[rsp + 64];
- auto ORIG_A = qword[rsp + 80];
- auto ORIG_SP = qword[rsp + 120];
-
- auto ZSTRIDE = zmm4;
- auto VALPHA = zmm6;
- auto VBETA = zmm7;
- auto VBIAS1 = zmm1;
- auto VBIAS2 = zmm2;
- auto VBIAS3 = zmm3;
-
- auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
- auto PREFETCHSIZEB = 16;
-
- Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
- zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
- zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
-
- // Function for packing if needed
- auto do_pack = [&](int unroll_m) {
- Label pack2, pack3, pack4, pack10;
-
- mov(BO1, A);
- lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
- mov(LL, K);
- sar(LL, 2);
- jle(pack3, T_NEAR);
- align(16);
-
- L(pack2);
- if (!isTransA) {
- for (int i = 0; i < 4; i++) {
- vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
- if (unroll_m > 16)
- vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
- if (unroll_m > 32)
- vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
- add(BO1, LDA);
-
- vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
- | k1,
- zmm0);
- if (unroll_m > 16)
- vmovups(ptr[AO1
- + (unroll_m * i + 1 * 16 - OFFSET)
- * SIZE]
- | k2,
- zmm1);
- if (unroll_m > 32)
- vmovups(ptr[AO1
- + (unroll_m * i + 2 * 16 - OFFSET)
- * SIZE]
- | k3,
- zmm2);
- }
- } else {
- for (int i = 0; i < 4; i++) {
- kmovw(k4, k1);
- vgatherqps(ymm5 | k4,
- ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
- lea(BO2, ptr[BO1 + LDA * 8]);
- kshiftrw(k4, k1, 8);
- vgatherqps(ymm6 | k4,
- ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
- vshuff64x2(zmm0, zmm5, zmm6, 0x44);
-
- if (unroll_m > 16) {
- lea(BO2, ptr[BO2 + LDA * 8]);
- kmovw(k4, k2);
- vgatherqps(ymm5 | k4,
- ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- kshiftrw(k4, k2, 8);
- vgatherqps(ymm6 | k4,
- ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
- vshuff64x2(zmm1, zmm5, zmm6, 0x44);
- }
-
- if (unroll_m > 32) {
- lea(BO2, ptr[BO2 + LDA * 8]);
- kmovw(k4, k3);
- vgatherqps(ymm5 | k4,
- ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- kshiftrw(k4, k3, 8);
- vgatherqps(ymm6 | k4,
- ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- vshuff64x2(zmm2, zmm5, zmm6, 0x44);
- }
-
- vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
- zmm0 | k1);
- if (unroll_m > 16)
- vmovups(ptr[AO1
- + (unroll_m * i + 1 * 16 - OFFSET)
- * SIZE],
- zmm1 | k2);
- if (unroll_m > 32)
- vmovups(ptr[AO1
- + (unroll_m * i + 2 * 16 - OFFSET)
- * SIZE],
- zmm2 | k3);
- }
- add(BO1, 4 * SIZE);
- }
- add(AO1, unroll_m * 4 * SIZE);
-
- sub(LL, 1);
- jg(pack2, T_NEAR);
- align(16);
-
- L(pack3);
- mov(LL, K);
- and_(LL, 3);
- jle(pack10, T_NEAR);
- align(16);
-
- L(pack4);
- if (!isTransA) {
- vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
- if (unroll_m > 16)
- vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
- if (unroll_m > 32)
- vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
- add(BO1, LDA);
- } else {
- kmovw(k4, k1);
- vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO1 + LDA * 8]);
- kshiftrw(k4, k1, 8);
- vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- vshuff64x2(zmm0, zmm5, zmm6, 0x44);
-
- if (unroll_m > 16) {
- lea(BO2, ptr[BO2 + LDA * 8]);
- kmovw(k4, k2);
- vgatherqps(ymm5 | k4,
- ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- kshiftrw(k4, k2, 8);
- vgatherqps(ymm6 | k4,
- ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- vshuff64x2(zmm1, zmm5, zmm6, 0x44);
- }
-
- if (unroll_m > 32) {
- lea(BO2, ptr[BO2 + LDA * 8]);
- kmovw(k4, k3);
- vgatherqps(ymm5 | k4,
- ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- kshiftrw(k4, k3, 8);
- vgatherqps(ymm6 | k4,
- ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 8]);
- vshuff64x2(zmm2, zmm5, zmm6, 0x44);
- }
- add(BO1, SIZE);
- }
-
- vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
- zmm0 | k1);
- if (unroll_m > 16)
- vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
- zmm1 | k2);
- if (unroll_m > 32)
- vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
- zmm2 | k3);
-
- add(AO1, unroll_m * SIZE);
- sub(LL, 1);
- jg(pack4, T_NEAR);
- align(16);
-
- L(pack10);
- };
-
- // Function to update C, covering masking and other considerations
- auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
- bool useScale = false) {
- vmulps(reg, reg, VALPHA);
- if (!isBeta0) {
- if (!useScale) {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(zmm0, ptr[CO1 + offset * SIZE]);
- else
- vmovups(zmm0, ptr[CO2 + offset * SIZE]);
- break;
- case 1:
- if (useCO1)
- vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
- else
- vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
- break;
- case 2:
- if (useCO1)
- vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
- else
- vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
- break;
- case 3:
- if (useCO1)
- vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
- else
- vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);
- break;
- }
- } else {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
- else
- vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
- break;
- case 1:
- if (useCO1)
- vmovups(zmm0 | k1 | T_z,
- ptr[CO1 + LDC + offset * SIZE]);
- else
- vmovups(zmm0 | k1 | T_z,
- ptr[CO2 + LDC + offset * SIZE]);
- break;
- case 2:
- if (useCO1)
- vmovups(zmm0 | k2 | T_z,
- ptr[CO1 + LDC + offset * SIZE]);
- else
- vmovups(zmm0 | k2 | T_z,
- ptr[CO2 + LDC + offset * SIZE]);
- break;
- case 3:
- if (useCO1)
- vmovups(zmm0 | k3 | T_z,
- ptr[CO1 + LDC + offset * SIZE]);
- else
- vmovups(zmm0 | k3 | T_z,
- ptr[CO2 + LDC + offset * SIZE]);
- break;
- }
- }
- if (!isBetaN) {
- vaddps(zmm0, reg, zmm0);
- } else {
- vfmadd132ps(zmm0, reg, VBETA);
- }
- if (!useScale) {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], zmm0);
- else
- vmovups(ptr[CO2 + offset * SIZE], zmm0);
- break;
- case 1:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
- else
- vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
- break;
- case 2:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
- else
- vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
- break;
- case 3:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
- else
- vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);
- break;
- }
- } else {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
- break;
- case 1:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
- break;
- case 2:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
- break;
- case 3:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);
- break;
- }
- }
- } else {
- if (!useScale) {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], reg);
- else
- vmovups(ptr[CO2 + offset * SIZE], reg);
- break;
- case 1:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], reg | k1);
- else
- vmovups(ptr[CO2 + offset * SIZE], reg | k1);
- break;
- case 2:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], reg | k2);
- else
- vmovups(ptr[CO2 + offset * SIZE], reg | k2);
- break;
- case 3:
- if (useCO1)
- vmovups(ptr[CO1 + offset * SIZE], reg | k3);
- else
- vmovups(ptr[CO2 + offset * SIZE], reg | k3);
- break;
- }
- } else {
- switch (mask) {
- case 0:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
- break;
- case 1:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
- break;
- case 2:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
- break;
- case 3:
- if (useCO1)
- vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
- else
- vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);
- break;
- }
- }
- }
- vpxorq(reg, reg, reg);
- };
-
- // Loop with unroll_n - 2 FMAs; called by innerkernel
- auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
- for (int i = 2; i < unroll_n; i++) {
- if (ver == ver_avx512_core) {
- if (!isTransB) {
- switch (i) {
- case 2:
- vbroadcastss(
- zmm3,
- ptr[BO1 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- break;
- case 3:
- vbroadcastss(
- zmm3,
- ptr[BO1 + LDB3
- + (iteration - OFFSET) * SIZE]);
- break;
- case 4:
- vbroadcastss(zmm3,
- ptr[BO2 + (iteration - OFFSET) * SIZE]);
- break;
- case 5:
- vbroadcastss(
- zmm3,
- ptr[BO2 + LDB * 1
- + (iteration - OFFSET) * SIZE]);
- break;
- case 6:
- vbroadcastss(
- zmm3,
- ptr[BO2 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- break;
- case 7:
- vbroadcastss(
- zmm3,
- ptr[BO2 + LDB3
- + (iteration - OFFSET) * SIZE]);
- break;
- }
- } else {
- vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
- }
- vfmadd231ps(regs[i], zmm3, zmm0);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm3, zmm1);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm3, zmm2);
- } else {
- if (!isTransB) {
- switch (i) {
- case 2:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO1 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO1 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO1 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- break;
- case 3:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO1 + LDB3
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO1 + LDB3
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO1 + LDB3
- + (iteration - OFFSET) * SIZE]);
- break;
- case 4:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO2 + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO2 + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO2 + (iteration - OFFSET) * SIZE]);
- break;
- case 5:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO2 + LDB * 1
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO2 + LDB * 1
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO2 + LDB * 1
- + (iteration - OFFSET) * SIZE]);
- break;
- case 6:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO2 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO2 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO2 + LDB * 2
- + (iteration - OFFSET) * SIZE]);
- break;
- case 7:
- vfmadd231ps(regs[i], zmm0,
- zword_b[BO2 + LDB3
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO2 + LDB3
- + (iteration - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO2 + LDB3
- + (iteration - OFFSET) * SIZE]);
- break;
- }
- } else {
- vfmadd231ps(
- regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[i + 8], zmm1,
- zword_b[BO1 + (i - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[i + 16], zmm2,
- zword_b[BO1 + (i - OFFSET) * SIZE]);
- }
- }
- }
- };
-
- // Innerkernel; called by kernel
- auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
- bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
- for (int i = 0; i < 8; i++) {
- if (!isDirect) {
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
- * SIZE]);
- if (unroll_m >= 32)
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
- * SIZE]);
- if (unroll_m >= 48)
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
- * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
- if (unroll_m >= 32)
- prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
- if (unroll_m >= 48)
- prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
- }
-
- if (!isDirect) {
- if (i != 0) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0,
- ptr[AO1
- + (unroll_m * i + 0 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1
- + (unroll_m * i + 0 * 16 - OFFSET)
- * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1
- + (unroll_m * i + 1 * 16
- - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1
- + (unroll_m * i + 1 * 16
- - OFFSET)
- * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1
- + (unroll_m * i + 2 * 16
- - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1
- + (unroll_m * i + 2 * 16
- - OFFSET)
- * SIZE]);
- }
- }
- }
- } else {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- if (ver == ver_avx512_core) {
- if (!isTransB) {
- vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
- }
- vfmadd231ps(regs[0], zmm3, zmm0);
- if (unroll_m >= 32)
- vfmadd231ps(regs[0 + 8], zmm3, zmm1);
- if (unroll_m >= 48)
- vfmadd231ps(regs[0 + 16], zmm3, zmm2);
- } else {
- if (!isTransB) {
- vfmadd231ps(regs[0], zmm0,
- zword_b[BO1 + (i - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[0 + 8], zmm1,
- zword_b[BO1 + (i - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[0 + 16], zmm2,
- zword_b[BO1 + (i - OFFSET) * SIZE]);
- } else {
- vfmadd231ps(regs[0], zmm0,
- zword_b[BO1 + (0 - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[0 + 8], zmm1,
- zword_b[BO1 + (0 - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[0 + 16], zmm2,
- zword_b[BO1 + (0 - OFFSET) * SIZE]);
- }
- }
-
- if (unroll_n >= i + 1) {
- if (!isTransB) {
- switch (i) {
- case 0:
- prefetcht0(
- ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 1:
- prefetcht0(ptr[BO1 + LDB
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 2:
- prefetcht0(ptr[BO1 + LDB * 2
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 3:
- prefetcht0(ptr[BO1 + LDB3
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 4:
- prefetcht0(
- ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 5:
- prefetcht0(ptr[BO2 + LDB
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 6:
- prefetcht0(ptr[BO2 + LDB * 2
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- case 7:
- prefetcht0(ptr[BO2 + LDB3
- + (PREFETCHSIZEB - OFFSET) * SIZE]);
- break;
- }
- }
- }
-
- if (unroll_n >= 2) {
- if (ver == ver_avx512_core) {
- if (!isTransB) {
- vbroadcastss(zmm3,
- ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
- }
- vfmadd231ps(regs[1], zmm3, zmm0);
- if (unroll_m >= 32)
- vfmadd231ps(regs[1 + 8], zmm3, zmm1);
- if (unroll_m >= 48)
- vfmadd231ps(regs[1 + 16], zmm3, zmm2);
- } else {
- if (!isTransB) {
- vfmadd231ps(regs[1], zmm0,
- zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[1 + 8], zmm1,
- zword_b[BO1 + LDB * 1
- + (i - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[1 + 16], zmm2,
- zword_b[BO1 + LDB * 1
- + (i - OFFSET) * SIZE]);
- } else {
- vfmadd231ps(regs[1], zmm0,
- zword_b[BO1 + (1 - OFFSET) * SIZE]);
- if (unroll_m >= 32)
- vfmadd231ps(regs[1 + 8], zmm1,
- zword_b[BO1 + (1 - OFFSET) * SIZE]);
- if (unroll_m >= 48)
- vfmadd231ps(regs[1 + 16], zmm2,
- zword_b[BO1 + (1 - OFFSET) * SIZE]);
- }
- }
- }
-
- if (isCopy) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(ptr[LDA4
- + (unroll_m * i + 0 * 16 - OFFSET)
- * SIZE],
- zmm0);
- } else {
- vmovups(ptr[LDA4
- + (unroll_m * i + 0 * 16 - OFFSET)
- * SIZE],
- zmm0 | k1);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(ptr[LDA4
- + (unroll_m * i + 1 * 16 - OFFSET)
- * SIZE],
- zmm1);
- } else {
- vmovups(ptr[LDA4
- + (unroll_m * i + 1 * 16 - OFFSET)
- * SIZE],
- zmm1 | k2);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(ptr[LDA4
- + (unroll_m * i + 2 * 16 - OFFSET)
- * SIZE],
- zmm2);
- } else {
- vmovups(ptr[LDA4
- + (unroll_m * i + 2 * 16 - OFFSET)
- * SIZE],
- zmm2 | k3);
- }
- }
- if (i == 7)
- sub(LDA4, -unroll_m * 8 * SIZE);
- }
- fmaloop(unroll_m, unroll_n, i);
-
- if (i == 1) {
- if (doCPrefetch) {
- if (ver == ver_avx512_core)
- prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
- else
- prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
- }
- }
- if (i == 3) {
- if (doCPrefetch && unroll_m >= 32) {
- if (ver == ver_avx512_core)
- prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
- else
- prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
- }
- if (!isTransA) {
- if (ver == ver_avx512_core)
- prefetcht0(ptr[AA + 16 * 0 * SIZE]);
- else
- prefetcht2(ptr[AA + 16 * 0 * SIZE]);
- }
- }
- if (i == 5) {
- if (doCPrefetch) {
- if (unroll_m >= 48) {
- if (ver == ver_avx512_core)
- prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
- else
- prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
- }
- add(CO2, LDC);
- }
- if (!isTransA) {
- if (unroll_m >= 32) {
- if (ver == ver_avx512_core)
- prefetcht0(ptr[AA + 16 * 1 * SIZE]);
- else
- prefetcht2(ptr[AA + 16 * 1 * SIZE]);
- }
- }
- }
-
- if (isTransB) {
- prefetcht0(ptr[BO1 + BO2]);
- add(BO1, LDB);
- }
- } // end of for loop
-
- if (!isTransB) {
- sub(BO1, -8 * SIZE);
- if (unroll_n >= 4)
- sub(BO2, -8 * SIZE);
- }
- if (!isTransA) {
- if (unroll_m >= 48) {
- if (ver == ver_avx512_core)
- prefetcht0(ptr[AA + 16 * 2 * SIZE]);
- else
- prefetcht2(ptr[AA + 16 * 2 * SIZE]);
- }
- lea(AA, ptr[AA + LDA]);
- }
-
- if (!isDirect) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0,
- ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1
- + (unroll_m * 8 + 1 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1
- + (unroll_m * 8 + 1 * 16 - OFFSET)
- * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1
- + (unroll_m * 8 + 2 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1
- + (unroll_m * 8 + 2 * 16 - OFFSET)
- * SIZE]);
- }
- }
- sub(AO1, -unroll_m * 8 * SIZE);
- }
-
- sub(LL, 1);
- };
-
- // Main kernel; does prefetching and calls innerkernel
- // After calculating results in registers, writes back to C matrix by
- // calling update
- auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
- bool isCopy, bool isUnmasked = true) {
- if (!isDirect) {
- lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
- } else {
- mov(AO1, A);
- }
-
- if (isCopy) {
- lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
- } else {
- auto step = ver == ver_avx512_core ? 2 : 4;
- lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
- }
-
- if (isTransB) {
- lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
- }
-
- if (!isDirect) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0,
- ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1
- + (unroll_m * 0 + 1 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1
- + (unroll_m * 0 + 1 * 16 - OFFSET)
- * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1
- + (unroll_m * 0 + 2 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1
- + (unroll_m * 0 + 2 * 16 - OFFSET)
- * SIZE]);
- }
- }
- }
-
- Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
-
- mov(LL, K);
- sar(LL, 3);
- sub(LL, SECOND_FETCH);
- jle(kernel13, T_NEAR);
- align(16);
-
- L(kernel12);
- innerkernel(
- unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
- jg(kernel12, T_NEAR);
- align(16);
-
- L(kernel13);
- lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
- add(LL, unroll_n);
- jle(kernel15, T_NEAR);
- align(16);
-
- L(kernel14);
- innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
- jg(kernel14, T_NEAR);
- align(16);
-
- L(kernel15);
- mov(LL, K);
- and_(LL, 7);
- jle(kernel18, T_NEAR);
- align(16);
-
- L(kernel16);
- if (isDirect) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- for (int i = 0; i < unroll_n; i++) {
- if (!isTransB) {
- switch (i) {
- case 0:
- vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
- break;
- case 1:
- vbroadcastss(
- zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
- break;
- case 2:
- vbroadcastss(
- zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
- break;
- case 3:
- vbroadcastss(
- zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]);
- break;
- case 4:
- vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]);
- break;
- case 5:
- vbroadcastss(
- zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
- break;
- case 6:
- vbroadcastss(
- zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
- break;
- case 7:
- vbroadcastss(
- zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]);
- break;
- }
- } else {
- vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
- }
- vfmadd231ps(regs[i], zmm3, zmm0);
- if (unroll_m >= 32) {
- vfmadd231ps(regs[i + 8], zmm3, zmm1);
- }
- if (unroll_m >= 48) {
- vfmadd231ps(regs[i + 16], zmm3, zmm2);
- }
- }
-
- if (isCopy) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
- zmm0);
- } else {
- vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
- zmm0 | k1);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(ptr[LDA4
- + (unroll_m * 0 + 1 * 16 - OFFSET)
- * SIZE],
- zmm1);
- } else {
- vmovups(ptr[LDA4
- + (unroll_m * 0 + 1 * 16 - OFFSET)
- * SIZE],
- zmm1 | k2);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(ptr[LDA4
- + (unroll_m * 0 + 2 * 16 - OFFSET)
- * SIZE],
- zmm2);
- } else {
- vmovups(ptr[LDA4
- + (unroll_m * 0 + 2 * 16 - OFFSET)
- * SIZE],
- zmm2 | k3);
- }
- }
- sub(LDA4, -unroll_m * SIZE);
- }
-
- if (!isDirect) {
- if (isUnmasked || unroll_m > 16) {
- vmovups(zmm0,
- ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
- } else {
- vmovups(zmm0 | k1 | T_z,
- ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32) {
- vmovups(zmm1, ptr[AO1
- + (unroll_m * 1 + 1 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm1 | k2 | T_z,
- ptr[AO1
- + (unroll_m * 1 + 1 * 16 - OFFSET)
- * SIZE]);
- }
- }
- if (unroll_m >= 48) {
- if (isUnmasked) {
- vmovups(zmm2, ptr[AO1
- + (unroll_m * 1 + 2 * 16 - OFFSET)
- * SIZE]);
- } else {
- vmovups(zmm2 | k3 | T_z,
- ptr[AO1
- + (unroll_m * 1 + 2 * 16 - OFFSET)
- * SIZE]);
- }
- }
- sub(AO1, -unroll_m * SIZE);
- }
-
- if (!isTransB) {
- sub(BO1, -SIZE);
- if (unroll_n >= 4) {
- sub(BO2, -SIZE);
- }
- } else {
- add(BO1, LDB);
- }
-
- sub(LL, 1);
- jg(kernel16, T_NEAR);
- align(16);
-
- L(kernel18);
- vbroadcastss(VALPHA, ALPHA);
-
- if (isBetaN) {
- vbroadcastss(VBETA, BETA);
- }
-
- // Write back the results; all beta cases need to be handled
- if (hasBias) {
- mov(BIAS1, BIAS);
- if (isUnmasked || unroll_m > 16)
- vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
- else
- vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
- if (unroll_m >= 32) {
- if (isUnmasked || unroll_m > 32)
- vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
- else
- vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
- }
- if (unroll_m >= 48) {
- if (isUnmasked)
- vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
- else
- vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
- }
- }
-
- for (int i = 0; i < unroll_n; i++) {
- bool useScale = i % 2 != 0;
- bool useCO1 = i < 2;
- if (i == 2)
- lea(CO2, ptr[CO1 + LDC * 2]);
- if (i == 4 || i == 6)
- lea(CO2, ptr[CO2 + LDC * 2]);
- if (hasBias)
- vaddps(regs[i], VBIAS1, regs[i]);
- if (isUnmasked || unroll_m > 16) {
- update(regs[i], useCO1, 0, 0, useScale);
- } else {
- update(regs[i], useCO1, 0, 1, useScale);
- }
- if (unroll_m >= 32) {
- if (hasBias)
- vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
- if (isUnmasked || unroll_m > 32) {
- update(regs[i + 8], useCO1, 16, 0, useScale);
- } else {
- update(regs[i + 8], useCO1, 16, 2, useScale);
- }
- }
- if (unroll_m >= 48) {
- if (hasBias)
- vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
- if (isUnmasked) {
- update(regs[i + 16], useCO1, 32, 0, useScale);
- } else {
- update(regs[i + 16], useCO1, 32, 3, useScale);
- }
- }
- }
-
- switch (unroll_n) {
- case 1: add(CO1, LDC); break;
- case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
- case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
- case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
- case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
- case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
- case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
- case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
- }
-
- // Compute next address of B
- if (!isTransB) {
- lea(rax, ptr[K * SIZE]);
- switch (unroll_n) {
- case 1:
- add(BO1, LDB);
- add(BO2, LDB);
- break;
- case 2:
- lea(BO1, ptr[BO1 + LDB * 2]);
- lea(BO2, ptr[BO2 + LDB * 2]);
- break;
- case 3:
- lea(BO1, ptr[BO1 + LDB3]);
- lea(BO2, ptr[BO2 + LDB3]);
- break;
- case 4:
- lea(BO1, ptr[BO1 + LDB * 4]);
- lea(BO2, ptr[BO2 + LDB * 4]);
- break;
- case 5:
- lea(BO1, ptr[BO1 + LDB * 4]);
- add(BO1, LDB);
- lea(BO2, ptr[BO2 + LDB * 4]);
- add(BO2, LDB);
- break;
- case 6:
- lea(BO1, ptr[BO1 + LDB3 * 2]);
- lea(BO2, ptr[BO2 + LDB3 * 2]);
- break;
- case 7:
- lea(BO1, ptr[BO1 + LDB * 8]);
- sub(BO1, LDB);
- lea(BO2, ptr[BO2 + LDB * 8]);
- sub(BO2, LDB);
- break;
- case 8:
- lea(BO1, ptr[BO1 + LDB * 8]);
- lea(BO2, ptr[BO2 + LDB * 8]);
- break;
- }
- sub(BO1, rax);
- sub(BO2, rax);
- } else {
- mov(rax, LDB);
- imul(rax, K);
- sub(BO1, rax);
- add(BO1, unroll_n * SIZE);
- }
- };
-
- // High-level subroutine; does packing if needed, then splits C matrix.
- // Operates on chunks of 48 rows, 8 columns at a time (handling tail
- // cases appropriately by doing 32 or 16 rows, and/or with masking,
- // and/or fewer columns).
- auto subloop = [&](int unroll_m) {
- Label l_subloop_20x[8], l_subloop_mask_20x[8];
- Label l_subloop_30x[8], l_subloop_mask_30x[8];
-
- Label subloop11, subloop11mask;
- Label subloop30, subloop30mask;
- Label subloop31, subloop31mask;
- Label subloop96;
- Label subloop98, subloop98mask;
- Label subloop99;
-
- // Create mask
- mov(BO1, rcx);
- mov(rcx, M);
- sub(rcx, unroll_m - 16);
- mov(CO1, 16);
- cmp(rcx, 16);
-
- cmovg(rcx, CO1);
- mov(rax, 1);
- sal(rax, cl);
- sub(rax, 1);
- mov(rcx, 0xffff);
-
- if (unroll_m == 16) {
- kmovw(k1, eax);
- } else if (unroll_m == 32) {
- kmovw(k1, ecx);
- kmovw(k2, eax);
- } else {
- kmovw(k1, ecx);
- kmovw(k2, ecx);
- kmovw(k3, eax);
- }
- mov(rcx, BO1);
-
- and_(rax, 0xffff);
- cmp(rax, 0xffff);
- jne(subloop96, T_NEAR);
-
- if (isTransA) {
- do_pack(unroll_m);
- }
-
- mov(CO1, C);
- add(C, unroll_m * SIZE);
-
- mov(BO1, B);
- if (!isTransB) {
- lea(BO2, ptr[B + LDB * 4]);
- }
-
- if (!isTransA) {
- lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
- cmp(M, UNROLL_M);
- jg(subloop98, T_NEAR);
-
- mov(AA, ORIG_A);
- lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(subloop98);
- }
-
- mov(LL, N);
- mov(I, LL);
- if (!isTransA) {
- // If N is too small, skip copy operation
- cmp(LL, UNROLL_N * 3);
- jle(subloop30, T_NEAR);
-
- // If A is not aligned to cache line
- cmp(FLAG, 0);
- je(subloop30, T_NEAR);
- } else {
- cmp(LL, UNROLL_N);
- jl(l_subloop_20x[1], T_NEAR);
- }
- align(16);
-
- if (!isTransA) {
- kernel(unroll_m, UNROLL_N, true, true);
- } else {
- kernel(unroll_m, UNROLL_N, false, false);
- }
-
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jl(l_subloop_20x[1], T_NEAR);
- align(16);
-
- L(subloop11);
- kernel(unroll_m, UNROLL_N, false, false);
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop11, T_NEAR);
- align(16);
-
- for (int i = 1; i <= 7; i++) {
- L(l_subloop_20x[i]);
- cmp(I, i);
- if (i < 7) {
- jne(l_subloop_20x[i + 1], T_NEAR);
- } else {
- jne(subloop99, T_NEAR);
- }
- kernel(unroll_m, i, false, false);
- jmp(subloop99, T_NEAR);
- align(16);
- }
-
- if (!isTransA) {
- L(subloop30);
- cmp(I, UNROLL_N);
- jl(l_subloop_30x[1], T_NEAR);
- align(16);
-
- L(subloop31);
- kernel(unroll_m, UNROLL_N, true, false);
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop31, T_NEAR);
- align(16);
-
- for (int i = 1; i <= 7; i++) {
- L(l_subloop_30x[i]);
- cmp(I, i);
- if (i < 7) {
- jne(l_subloop_30x[i + 1], T_NEAR);
- } else {
- jne(subloop99, T_NEAR);
- }
- kernel(unroll_m, i, true, false);
- if (i < 7)
- jmp(subloop99, T_NEAR);
- align(16);
- }
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop96);
- if (isTransA) {
- do_pack(unroll_m);
- }
-
- mov(CO1, C);
- add(C, unroll_m * SIZE);
- mov(BO1, B);
- if (!isTransB) {
- lea(BO2, ptr[B + LDB * 4]);
- }
-
- if (!isTransA) {
- lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
- cmp(M, UNROLL_M);
- jg(subloop98mask, T_NEAR);
- mov(AA, ORIG_A);
- lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
- L(subloop98mask);
- }
-
- mov(LL, N);
- mov(I, LL);
- if (!isTransA) {
- // If N is too small, skip copy operation
- cmp(LL, UNROLL_N * 3);
- jle(subloop30mask, T_NEAR);
-
- // If A is not aligned to cache line
- cmp(FLAG, 0);
- je(subloop30mask, T_NEAR);
- } else {
- cmp(LL, UNROLL_N);
- jl(l_subloop_mask_20x[1], T_NEAR);
- }
- align(16);
-
- if (!isTransA) {
- kernel(unroll_m, UNROLL_N, true, true, false);
- } else {
- kernel(unroll_m, UNROLL_N, false, false, false);
- }
-
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jl(l_subloop_mask_20x[1], T_NEAR);
- align(16);
-
- L(subloop11mask);
- kernel(unroll_m, UNROLL_N, false, false, false);
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop11mask, T_NEAR);
- align(16);
-
- for (int i = 1; i <= 7; i++) {
- L(l_subloop_mask_20x[i]);
- cmp(I, i);
- if (i < 7) {
- jne(l_subloop_mask_20x[i + 1], T_NEAR);
- } else {
- jne(subloop99, T_NEAR);
- }
- kernel(unroll_m, i, false, false, false);
- jmp(subloop99, T_NEAR);
- align(16);
- }
-
- if (!isTransA) {
- L(subloop30mask);
- cmp(I, UNROLL_N);
- jl(l_subloop_mask_30x[1], T_NEAR);
- align(16);
-
- L(subloop31mask);
- kernel(unroll_m, UNROLL_N, true, false, false);
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop31mask, T_NEAR);
- align(16);
-
- for (int i = 1; i <= 7; i++) {
- L(l_subloop_mask_30x[i]);
- cmp(I, i);
- if (i < 7) {
- jne(l_subloop_mask_30x[i + 1], T_NEAR);
- } else {
- jne(subloop99, T_NEAR);
- }
- kernel(unroll_m, i, true, false, false);
- if (i < 7)
- jmp(subloop99, T_NEAR);
- align(16);
- }
- }
-
- L(subloop99);
- // Compute address for A
- if (!isTransA) {
- add(A, unroll_m * SIZE);
- } else {
- mov(rax, LDA);
- imul(rax, rax, unroll_m);
- add(A, rax);
- }
-
- // Compute next address of BIAS
- if (hasBias) {
- add(BIAS, unroll_m * SIZE);
- }
- };
-
- preamble();
-
- Label buffer_in_ws, buffer_allocated;
-
- // Get the registers
- mov(B, ARG_B);
- mov(LDB, ARG_LDB);
- mov(r15, ARG_BETA);
- mov(r12, ARG_C);
- if (hasBias)
- mov(r10, ARG_BIAS);
- mov(LDC, ARG_LDC);
- mov(rbp, rsp);
-
- vmovss(xmm0, ptr[ARG_ALPHA]);
- vmovss(xmm1, ptr[r15]);
-
-#if _WIN32
- mov(A, ARG_A);
- mov(LDA, ARG_LDA);
-#endif
-
- cmp(K, STACK_K_CAPACITY);
- jg(buffer_in_ws, T_NEAR);
-
- // Create buffer and align to 4kB page
- lea(rax, ptr[K * SIZE]);
- imul(rax, rax, 0x30);
- add(rax, 256);
- sub(rsp, rax);
- and_(rsp, -PAGE_4K);
- jmp(buffer_allocated, T_NEAR);
-
- L(buffer_in_ws);
- mov(rsp, ARG_WS);
-
- L(buffer_allocated);
-
- mov(ORIG_SP, rbp);
- mov(M, ARG_M);
- mov(N, ARG_N);
- mov(C, r12);
- if (hasBias)
- mov(BIAS, r10);
- vmovss(ALPHA, xmm0);
- vmovss(BETA, xmm1);
- sub(A, -OFFSET * SIZE);
- sub(B, -OFFSET * SIZE);
- mov(ORIG_A, A);
- sal(LDA, BASE_SHIFT);
- sal(LDB, BASE_SHIFT);
- sal(LDC, BASE_SHIFT);
- lea(LDB3, ptr[LDB + LDB * 2]);
-
- if (isTransA) {
- vpbroadcastq(zmm2, LDA);
- vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
- mov(rax, -2);
- kmovw(k4, eax);
-
- for (int i = 0; i < 6; i++) {
- vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
- kshiftlw(k4, k4, 1);
- }
- vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
- }
-
- // Check A alignment and leading dimension; take copy-based path as
- // needed
- mov(rax, LDA);
- or_(rax, A);
- and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);
- mov(FLAG, rax);
-
- for (int i = 8; i < 16; i++) {
- for (int j = 0; j < 3; j++) {
- vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
- }
- }
-
- Label main0, main1, main2, main999;
-
- cmp(M, 32);
- jle(main0, T_NEAR);
- align(16);
-
- L(main1);
- subloop(48);
- sub(M, UNROLL_M);
- cmp(M, 32);
- jg(main1, T_NEAR);
- align(16);
-
- L(main0);
- cmp(M, 16);
- jle(main2, T_NEAR);
-
- subloop(32);
- jmp(main999, T_NEAR);
- align(16);
-
- L(main2);
- cmp(M, 0);
- jle(main999, T_NEAR);
- subloop(16);
- align(16);
-
- L(main999);
- // Restore original stack
- mov(rsp, ORIG_SP);
-
- vzeroupper();
- postamble();
-
- ker_ = this->getCode<ker_t>();
- }
-
- typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
- const float *alpha, const float *a, dim_t lda,
- const float *b, dim_t ldb, const float *beta, float *c,
- dim_t ldc, const float *bias, float *ws);
-
- void operator()(dim_t m, dim_t n, dim_t k,
- const float *alpha, const float *a, dim_t lda,
- const float *b, dim_t ldb, const float *beta, float *c,
- dim_t ldc, const float *bias, float *ws) const
- {
- ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
- }
-
-private:
- ker_t ker_;
-};
-
-const xbyak_gemm *get_xbyak_gemm(
- bool isTransA, bool isTransB, float beta, bool hasBias) {
- auto beta_idx = [](float beta) {
- return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
- };
-
- // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
- static xbyak_gemm *kernel_table[2][2][2][3];
- static std::once_flag initialized;
- std::call_once(initialized, [=]{
- for (bool isTransA: {false, true})
- for (bool isTransB: {false, true})
- for (bool hasBias: {false, true})
- for (float beta: {0.0f, 1.0f, 2.0f}) {
- // nocopy sgemm with bias for beta != 0.0 is not supported
- if (hasBias && beta != 0.0)
- continue;
- kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
- new xbyak_gemm(isTransA, isTransB, beta, hasBias);
- }
- });
-
- return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
-}
-
-void sgemm_nocopy_driver(const char *transa,
- const char *transb, int m, int n, int k, const float *alpha,
- const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
- float *c, dim_t ldc, const float *bias, float *ws)
-{
- bool isTransA = (*transa == 'T' || *transa == 't');
- bool isTransB = (*transb == 'T' || *transb == 't');
-
- int Bm, sizeM, Bn, sizeN, Bk, sizeK;
-
- int i, j;
-
- if ((m <= 0) || (n <= 0))
- return;
-
- if ((k <= 0) || (alpha[0] == 0.)) {
-
- if (beta[0] == 0.) {
- for (j = 0; j < n; j++)
- for (i = 0; i < m; i++)
- c[i + j * ldc] = 0.0;
- } else if (beta[0] != 1.) {
- for (j = 0; j < n; j++)
- for (i = 0; i < m; i++)
- c[i + j * ldc] *= beta[0];
- }
-
- return;
- }
-
- assert(IMPLICATION(bias != nullptr, *beta == 0.0));
-
- // XXX: this happens on every thread...
- bool hasBias = (bias != nullptr);
- auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
- auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
- auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
- assert(ker_bn && ker_b1 && ker_b0);
-
- int BM = 4032, BN, BK;
- if (mayiuse(avx512_core)) {
- BN = isTransA ? 384 : 64;
- BK = 384;
- } else {
- BN = isTransA ? 96 : 64;
- BK = isTransB ? 96 : 192;
- if (!isTransA && !isTransB)
- BK = 128;
- }
- const float *curA, *curB, *curBias = nullptr;
- float *curC;
-
- for (Bk = 0; Bk < k; Bk += sizeK) {
- sizeK = k - Bk;
- if (sizeK >= BK * 2)
- sizeK = BK;
- else {
- if (sizeK > BK)
- sizeK = (sizeK + 1) / 2;
- }
-
- for (Bm = 0; Bm < m; Bm += sizeM) {
- sizeM = m - Bm;
- if (sizeM >= BM * 2)
- sizeM = BM;
- else {
- if (sizeM > BM + BM / 2)
- sizeM = (sizeM + 1) / 2;
- }
-
- for (Bn = 0; Bn < n; Bn += sizeN) {
- sizeN = n - Bn;
- if (sizeN >= BN * 2)
- sizeN = BN;
- else {
- if (sizeN > BN + BN / 2)
- sizeN = (sizeN + 1) / 2;
- }
-
- if (!isTransA) {
- curA = a + Bm + Bk * lda;
- } else {
- curA = a + Bk + Bm * lda;
- }
- if (!isTransB) {
- curB = b + Bk + Bn * ldb;
- } else {
- curB = b + Bn + Bk * ldb;
- }
- curC = c + Bm + (size_t)Bn * ldc;
- if (bias != nullptr) {
- if (Bk == 0) {
- curBias = bias + Bm;
- } else {
- curBias = nullptr;
- }
- }
- if (Bk == 0) {
- if (*beta == 0.0 && bias == nullptr)
- (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- else
- (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- } else {
- (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- }
- }
- }
- }
-}
-
-}
-
-mkldnn_status_t jit_avx512_common_gemm_f32(
- const char *transa, const char *transb,
- const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
- const float *A, const int *p_lda, const float *B, const int *p_ldb,
- const float *p_beta, float *C, const int *p_ldc, const float *bias)
-{
- using namespace mkldnn::impl::utils;
- using namespace avx512_common_gemm_f32;
- using namespace gemm_utils;
-
- if (*p_beta != 0 && bias)
- return ref_gemm(transa, transb, p_m, p_n, p_k,
- p_alpha, A, p_lda, B, p_lda, p_beta, C, p_ldc, bias);
-
- int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
-
- int m = *p_m;
- int n = *p_n;
- int k = *p_k;
- dim_t lda = *p_lda;
- dim_t ldb = *p_ldb;
- dim_t ldc = *p_ldc;
- float beta = *p_beta;
- int MB, NB, KB;
-
- int nthr_m, nthr_n, nthr_k, nthr_mn;
-
- // Determine threading partitioning
- calc_nthr_nocopy_avx512_common(
- m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
-
- // May not happen, but just in case
- if (nthr < nthr_m * nthr_n * nthr_k)
- nthr = nthr_m * nthr_n * nthr_k;
-
- nthr_mn = nthr_m * nthr_n;
-
- unsigned char * ompstatus_ = nullptr;
- unsigned char volatile *ompstatus = nullptr;
-
- float *c_buffers = nullptr;
- float *ws_buffers = nullptr;
-
- if (nthr_k > 1) {
- ompstatus_ = (unsigned char *) malloc(
- nthr * CACHE_LINE_SIZE,
- CACHE_LINE_SIZE);
- ompstatus = (unsigned char volatile *) ompstatus_;
- assert(ompstatus);
-
- for (int i = 0; i < nthr; i++)
- ompstatus[i * CACHE_LINE_SIZE] = 0;
-
- c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
- * sizeof(float), PAGE_4K);
- }
-
- const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
- const size_t ws_size_per_thr
- = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
- if (k > STACK_K_CAPACITY) {
- ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
- }
-
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_m, ithr_n, ithr_k, ithr_mn;
- int m_from, m_to, myM;
- int n_from, n_to, myN;
- int k_from, k_to, myK;
- int cbase, ibase;
- const float *myA, *myB, *myBias = nullptr;
- float *myC = C, myBeta;
- float *ws = ws_buffers ?
- ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
- dim_t ld = ldc;
-
- int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
-
- if (ithr < nthr_m * nthr_n * nthr_k) {
-
- ithr_mn = ithr % nthr_mn;
- ithr_m = ithr_mn % nthr_m;
- ithr_n = ithr_mn / nthr_m;
- ithr_k = ithr / nthr_mn;
-
- /* swap ithr_k for performance improvement */
- if (ithr_k == 0)
- ithr_k = nthr_k - 1;
- else if (ithr_k == nthr_k - 1)
- ithr_k = 0;
-
- m_from = MB * (ithr_m);
- m_to = MB * (ithr_m + 1);
- if (m_to > m)
- m_to = m;
- myM = m_to - m_from;
-
- n_from = NB * (ithr_n);
- n_to = NB * (ithr_n + 1);
- if (n_to > n)
- n_to = n;
- myN = n_to - n_from;
-
- k_from = KB * (ithr_k);
- k_to = KB * (ithr_k + 1);
- if (k_to > k)
- k_to = k;
- myK = k_to - k_from;
-
- cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
- ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
-
- if ((myM > 0) && (myN > 0)) {
-
- if (*transa == 'N' || *transa == 'n') {
- myA = &(A[m_from + k_from * lda]);
- } else {
- myA = &(A[k_from + m_from * lda]);
- }
- if (*transb == 'N' || *transb == 'n') {
- myB = &(B[k_from + n_from * ldb]);
- } else {
- myB = &(B[n_from + k_from * ldb]);
- }
- if (ithr_k == 0) {
- myC = &(C[m_from + n_from * ldc]);
- myBeta = beta;
- ld = ldc;
- if (bias)
- myBias = &(bias[m_from]);
- } else {
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
- myBeta = 0.0;
- ld = MB;
- myBias = nullptr;
- }
-
- sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
- lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
-
- if (nthr_k > 1 && !sum_later)
- ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
- }
-
- if (nthr_k > 1 && !sum_later) {
-
- // sum matrices partitioned along K dimension
- int n1, n2;
-
- partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
-
- if (ithr_k > 0) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
- + (dim_t)n1 * MB;
- /* need to wait until main thread finishes */
- while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
- };
-
- /* my cache is hot */
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
-
- for (int ik = 1; ik < nthr_k; ++ik) {
- if (ik != ithr_k) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
- + (dim_t)n1 * MB;
-
- while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
- };
-
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
- }
- }
- }
- });
-
-
- // handle C summation later
- if (nthr_k > 1 && ompstatus[0] == 0) {
-
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_m, ithr_n, ithr_k, ithr_mn;
- int m_from, m_to, myM;
- int n_from, n_to, myN;
- int cbase;
- float *myC = C;
-
- if (ithr < nthr_m * nthr_n * nthr_k) {
-
- ithr_mn = ithr % nthr_mn;
- ithr_m = ithr_mn % nthr_m;
- ithr_n = ithr_mn / nthr_m;
- ithr_k = ithr / nthr_mn;
-
- /* swap ithr_k for performance improvement */
- if (ithr_k == 0)
- ithr_k = nthr_k - 1;
- else if (ithr_k == nthr_k - 1)
- ithr_k = 0;
-
- m_from = MB * (ithr_m);
- m_to = MB * (ithr_m + 1);
- if (m_to > m)
- m_to = m;
- myM = m_to - m_from;
-
- n_from = NB * (ithr_n);
- n_to = NB * (ithr_n + 1);
- if (n_to > n)
- n_to = n;
- myN = n_to - n_from;
-
- cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
-
- if (nthr_k > 1) {
- // sum matrices partitioned along K dimension
- int n1, n2;
-
- partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
-
- if (ithr_k > 0) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
- + (dim_t)n1 * MB;
-
- /* my cache is hot */
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
-
- for (int ik = 1; ik < nthr_k; ++ik) {
- if (ik != ithr_k) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
- + (dim_t)n1 * MB;
-
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
- }
- }
- }
- });
- }
-
- free(c_buffers);
- free(ompstatus_);
- free(ws_buffers);
-
- return mkldnn_success;
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
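For context on the file removed above: its sgemm_nocopy_driver partitioned the operands into BM x BK and BK x BN panels and dispatched a beta==0, beta==1 or general-beta JIT kernel per panel. The following is a minimal plain-C++ sketch of that blocking structure only, not the removed code: it assumes column-major, non-transposed A and B, the block sizes are merely indicative of the avx512_core path, and a scalar loop stands in where the original invoked a JIT-generated kernel.

// Illustrative sketch (not part of the removed sources): cache-blocked,
// column-major SGEMM with the same loop nesting as the deleted driver.
#include <algorithm>
#include <cstddef>

static void sgemm_blocked_ref(int m, int n, int k, float alpha,
        const float *a, std::ptrdiff_t lda,
        const float *b, std::ptrdiff_t ldb,
        float beta, float *c, std::ptrdiff_t ldc)
{
    // Block sizes are illustrative; the removed driver tuned them per ISA
    // and per transpose mode.
    const int BM = 4032, BN = 64, BK = 384;

    // Apply beta once up front so the blocked K loop can simply accumulate.
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i)
            c[i + j * ldc] = (beta == 0.f) ? 0.f : c[i + j * ldc] * beta;

    for (int bk = 0; bk < k; bk += BK) {
        const int kb = std::min(BK, k - bk);
        for (int bm = 0; bm < m; bm += BM) {
            const int mb = std::min(BM, m - bm);
            for (int bn = 0; bn < n; bn += BN) {
                const int nb = std::min(BN, n - bn);
                // Reference inner kernel; the deleted code called a JIT
                // kernel (ker_b0 / ker_b1 / ker_bn) at this point.
                for (int j = 0; j < nb; ++j)
                    for (int p = 0; p < kb; ++p) {
                        const float bv = alpha * b[(bk + p) + (bn + j) * ldb];
                        for (int i = 0; i < mb; ++i)
                            c[(bm + i) + (bn + j) * ldc]
                                    += a[(bm + i) + (bk + p) * lda] * bv;
                    }
            }
        }
    }
}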
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
deleted file mode 100644
index d581b7fd71..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP
-#define JIT_AVX512_COMMON_GEMM_F32_HPP
-
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t jit_avx512_common_gemm_f32(
- const char *transa, const char *transb, const int *M,
- const int *N, const int *K, const float *alpha, const float *A,
- const int *lda, const float *B, const int *ldb, const float *beta,
- float *C, const int *ldc, const float *bias = nullptr);
-
-}
-}
-}
-
-#endif
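The entry point declared in the header removed above split the multiplication across threads along all three dimensions (nthr_m x nthr_n x nthr_k), each thread computing one MB x NB x KB tile, with K-partitioned partial results summed afterwards. Below is a reduced, standalone sketch of that tile assignment only; it is illustrative and deliberately omits details present in the deleted implementation, such as the ithr_k swap for cache reuse, the per-thread workspace, and the partial-C reduction buffers.

// Illustrative sketch (not part of the removed sources): mapping a flat
// thread index to its (M, N, K) tile, as the deleted parallel driver did.
#include <algorithm>

struct tile { int m_from, m_to, n_from, n_to, k_from, k_to; };

static tile assign_tile(int ithr, int nthr_m, int nthr_n,
        int m, int n, int k, int MB, int NB, int KB)
{
    const int nthr_mn = nthr_m * nthr_n;
    const int ithr_mn = ithr % nthr_mn;   // position within an M x N plane
    const int ithr_m  = ithr_mn % nthr_m; // row-block index
    const int ithr_n  = ithr_mn / nthr_m; // column-block index
    const int ithr_k  = ithr / nthr_mn;   // K-block index

    tile t;
    t.m_from = MB * ithr_m; t.m_to = std::min(m, MB * (ithr_m + 1));
    t.n_from = NB * ithr_n; t.n_to = std::min(n, NB * (ithr_n + 1));
    t.k_from = KB * ithr_k; t.k_to = std::min(k, KB * (ithr_k + 1));
    return t;
}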
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
deleted file mode 100644
index 60d4220837..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
+++ /dev/null
@@ -1,2705 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <cmath>
-#include <mutex>
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "ref_gemm_f32.hpp"
-#include "gemm_utils_f32.hpp"
-#include "jit_avx_gemm_f32.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-#define CACHE_LINE_SIZE 64
-
-#define STACKSIZE get_size_of_abi_save_regs()
-#if _WIN32
-#define STACK_K_CAPACITY 128
-#else
-#define STACK_K_CAPACITY 8192
-#endif
-#define SIZE 4
-#define OFFSET 32
-#define BASE_SHIFT 2
-#define SECOND_FETCH 14
-
-namespace avx_gemm_f32 {
-using namespace gemm_utils;
-
-struct xbyak_gemm : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
-
- xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
- void *code_ptr = nullptr,
- size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
- : jit_generator(code_ptr, code_size)
- {
- using namespace Xbyak;
-
- const bool is_avx2 = mayiuse(avx2);
- assert(IMPLICATION(!is_avx2, mayiuse(avx)));
-
- const int UNROLL_M = is_avx2 ? 16 : 8;
- const int UNROLL_N = 6;
-
- bool isBeta0 = (beta == 0.0);
- bool isBetaN = (!isBeta0 && beta != 1.0);
-
- // various definitions for convenience
- auto ARG_M = abi_param1;
- auto ARG_N = abi_param2;
- auto K = abi_param3;
- auto ARG_ALPHA = abi_param4;
-#ifdef _WIN32
- auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
- auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
- sizeof(float *) + STACKSIZE];
- const auto stackOffset = OFFSET_SHADOWSPACE +
- sizeof(float *) + STACKSIZE;
- auto A = rsi;
- auto LDA = rdi;
-#else
- auto ARG_A = r8;
- auto ARG_LDA = r9;
- const auto stackOffset = STACKSIZE;
- auto A = ARG_A;
- auto LDA = ARG_LDA;
-#endif
- auto ARG_B = ptr[rsp + 8 + stackOffset];
- auto ARG_LDB = ptr[rsp + 16 + stackOffset];
- auto ARG_BETA = ptr[rsp + 24 + stackOffset];
- auto ARG_C = ptr[rsp + 32 + stackOffset];
- auto ARG_LDC = ptr[rsp + 40 + stackOffset];
- auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
- auto ARG_WS = ptr[rsp + 56 + stackOffset];
-
- auto B = r11;
- auto LDB = rbx;
- auto LDC = r13;
- auto LL = rax;
- auto AO1 = abi_param2;
- auto BO1 = abi_param4;
- auto BO2 = rbp;
- auto CO1 = r14;
- auto CO2 = r15;
- auto LDB3 = r10;
- auto LDA4 = abi_param1;
- auto AA = r12;
- auto BIAS1 = abi_param1;
-
- auto M = qword[rsp + 0];
- auto N = qword[rsp + 8];
- auto FLAG = qword[rsp + 16];
- auto I = qword[rsp + 24];
- auto C = qword[rsp + 32];
- auto BIAS = qword[rsp + 40];
- auto ALPHA = qword[rsp + 48];
- auto BETA = qword[rsp + 64];
- auto ORIG_A = qword[rsp + 80];
- auto MASK = dword[rsp + 88];
- auto STRIDE = qword[rsp + 120];
- auto ORIG_SP = qword[rsp + 152];
-
- auto VALPHA = ymm1;
- auto VBETA = ymm2;
- auto VMASK = ymm3;
- auto VBIAS1 = ymm2;
- auto VBIAS2 = ymm4;
-
- auto PREFETCHSIZEA = 128;
- auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
-
- // Function for packing if needed
- auto do_pack = [&](
- int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
- Label pack2, pack3, pack4, pack10;
-
- int regIdx;
- Reg64 reg;
-
- mov(BO1, A);
- lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
-
- if (isTransA) {
- lea(BO2, ptr[BO1 + LDA * 4]);
- lea(CO1, ptr[LDA + LDA * 2]);
- vmovupd(ymm7, STRIDE);
- }
-
- mov(LL, K);
- sar(LL, 2);
- jle(pack3, T_NEAR);
- align(16);
-
- L(pack2);
- if (!isTransA) {
- for (int i = 0; i < 4; i++) {
- regIdx = (i % 2 == 0) ? 4 : 6;
- if (isLoad1Unmasked) {
- vmovups(Ymm(regIdx),
- ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(Ymm(regIdx), VMASK,
- ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m > 8) {
- if (isLoad2Unmasked) {
- vmovups(Ymm(regIdx + 1),
- ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(Ymm(regIdx + 1), VMASK,
- ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(BO1, LDA);
-
- vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
- Ymm(regIdx));
- if (unroll_m > 8) {
- vmovups(ptr[AO1
- + (unroll_m * i + 1 * 8 - OFFSET)
- * SIZE],
- Ymm(regIdx + 1));
- }
- }
-
- } else {
- if (isLoad1Unmasked) {
- for (int i = 0; i < 2; i++) {
- reg = (i % 2 == 0) ? BO1 : BO2;
- vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
- vmovups(xmm1,
- ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[reg + LDA * 2]);
- vunpcklps(xmm4, xmm0, xmm1);
- vunpckhps(xmm5, xmm0, xmm1);
- vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovups(xmm1,
- ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(xmm6, xmm0, xmm1);
- vunpckhps(xmm2, xmm0, xmm1);
-
- vunpcklpd(xmm0, xmm4, xmm6);
- vunpckhpd(xmm1, xmm4, xmm6);
- vmovups(ptr[AO1
- + (unroll_m * 0 + i * 4 - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * 1 + i * 4 - OFFSET)
- * SIZE],
- xmm1);
- vunpcklpd(xmm0, xmm5, xmm2);
- vunpckhpd(xmm1, xmm5, xmm2);
- vmovups(ptr[AO1
- + (unroll_m * 2 + i * 4 - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * 3 + i * 4 - OFFSET)
- * SIZE],
- xmm1);
- }
- } else if (is_avx2) {
- for (int i = 0; i < 2; i++) {
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm0,
- ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
- xmm4);
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm1,
- ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
- xmm4);
-
- vmovups(ptr[AO1
- + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * (2 * i + 1) + 0 * 4
- - OFFSET)
- * SIZE],
- xmm1);
- }
-
- lea(BO2, ptr[BO1 + LDA * 4]);
-
- for (int i = 0; i < 2; i++) {
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm0,
- ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
- xmm4);
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm1,
- ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
- xmm4);
-
- vmovups(ptr[AO1
- + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * (2 * i + 1) + 1 * 4
- - OFFSET)
- * SIZE],
- xmm1);
- }
-
- lea(BO2, ptr[BO2 + LDA * 4]);
- } else {
- vxorps(xmm4, xmm4, xmm4);
- lea(BO2, ptr[BO1 + LDA * 4]);
-
- auto el_cp = [&](int section, int ld_step) {
- RegExp src_addr = section == 0 ? BO1 : BO2;
- if (ld_step == 1 || ld_step == 2)
- src_addr = src_addr + LDA * ld_step;
- else if (ld_step == 3)
- src_addr = src_addr + CO1;
- src_addr = src_addr - OFFSET * SIZE;
-
- vmovups(Xmm(ld_step % 2), ptr[src_addr]);
- RegExp dst_addr = AO1
- + (ld_step + section * 4 - OFFSET) * SIZE;
- for (int off = 0; off < 4; ++off)
- pextrd(ptr[dst_addr + unroll_m * off * SIZE],
- Xmm(ld_step % 2), off);
- };
-
- Label l_end;
- el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
- el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
- el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
- el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
- el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
- el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
- el_cp(1, 2);
- L(l_end);
-
- lea(BO2, ptr[BO2 + LDA * 4]);
- }
-
- if (unroll_m >= 16) {
- assert(is_avx2);
- if (isLoad2Unmasked) {
- for (int i = 0; i < 2; i++) {
- vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovups(xmm1, ptr[BO2 + LDA * 1
- + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(xmm4, xmm0, xmm1);
- vunpckhps(xmm5, xmm0, xmm1);
- vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovups(xmm1, ptr[BO2 + LDA * 1
- + (0 * 8 - OFFSET) * SIZE]);
- if (i == 0)
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(xmm6, xmm0, xmm1);
- vunpckhps(xmm2, xmm0, xmm1);
-
- vunpcklpd(xmm0, xmm4, xmm6);
- vunpckhpd(xmm1, xmm4, xmm6);
- vmovups(ptr[AO1
- + (unroll_m * 0 + (i + 2) * 4
- - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * 1 + (i + 2) * 4
- - OFFSET)
- * SIZE],
- xmm1);
- vunpcklpd(xmm0, xmm5, xmm2);
- vunpckhpd(xmm1, xmm5, xmm2);
- vmovups(ptr[AO1
- + (unroll_m * 2 + (i + 2) * 4
- - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * 3 + (i + 2) * 4
- - OFFSET)
- * SIZE],
- xmm1);
- }
- } else {
- for (int i = 0; i < 2; i++) {
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm0,
- ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
- xmm4);
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm1,
- ptr[BO2 + ymm7
- + ((2 * i + 1) - OFFSET) * SIZE],
- xmm4);
-
- vmovups(ptr[AO1
- + (unroll_m * (2 * i) + 2 * 4
- - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * (2 * i + 1) + 2 * 4
- - OFFSET)
- * SIZE],
- xmm1);
- }
-
- lea(BO2, ptr[BO2 + LDA * 4]);
-
- for (int i = 0; i < 2; i++) {
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm0,
- ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
- xmm4);
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm1,
- ptr[BO2 + ymm7
- + ((2 * i + 1) - OFFSET) * SIZE],
- xmm4);
-
- vmovups(ptr[AO1
- + (unroll_m * (2 * i) + 3 * 4
- - OFFSET)
- * SIZE],
- xmm0);
- vmovups(ptr[AO1
- + (unroll_m * (2 * i + 1) + 3 * 4
- - OFFSET)
- * SIZE],
- xmm1);
- }
-
- lea(BO2, ptr[BO2 + LDA * 4]);
- }
- }
- add(BO1, (4 * SIZE));
- }
-
- add(AO1, unroll_m * 4 * SIZE);
- sub(LL, 1);
- jg(pack2, T_NEAR);
- align(16);
-
- L(pack3);
- mov(LL, K);
- and_(LL, 3);
- jle(pack10, T_NEAR);
- align(16);
-
- L(pack4);
- if (!isTransA) {
- if (isLoad1Unmasked) {
- vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m > 8) {
- if (isLoad2Unmasked) {
- vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm5, VMASK,
-                            ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(BO1, LDA);
- vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
- ymm4);
- if (unroll_m > 8) {
- vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
- ymm5);
- }
- } else {
- if (isLoad1Unmasked) {
- for (int i = 0; i < 2; i++) {
- reg = (i % 2 == 0) ? BO1 : BO2;
- vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
- vmovss(xmm0,
- ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[reg + LDA * 2]);
- vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
- }
- vunpcklpd(xmm1, xmm1, xmm2);
- vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
- xmm1);
-
- for (int i = 0; i < 2; i++) {
- vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovss(xmm0,
- ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
- }
- vunpcklpd(xmm1, xmm1, xmm2);
- vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
- xmm1);
- } else if (is_avx2) {
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
- xmm4);
- lea(BO2, ptr[BO1 + LDA * 4]);
- vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
- xmm1);
-
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
- xmm4);
- lea(BO2, ptr[BO2 + LDA * 4]);
- vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
- xmm1);
- } else {
- vxorps(xmm4, xmm4, xmm4);
- lea(BO2, ptr[BO1 + LDA * 4]);
-
- auto el_cp = [&](int section, int ld_step) {
- RegExp src_addr = section == 0 ? BO1 : BO2;
- if (ld_step == 1 || ld_step == 2)
- src_addr = src_addr + LDA * ld_step;
- else if (ld_step == 3)
- src_addr = src_addr + CO1;
- src_addr = src_addr - OFFSET * SIZE;
-
- vmovss(xmm1, ptr[src_addr]);
- RegExp dst_addr = AO1
- + (ld_step + section * 4 - OFFSET) * SIZE;
- movss(ptr[dst_addr], xmm1);
- };
-
- Label l_end;
- el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
- el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
- el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
- el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
- el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
- el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
- el_cp(1, 2);
- L(l_end);
-
- lea(BO2, ptr[BO2 + LDA * 4]);
- }
-
- if (unroll_m >= 16) {
- assert(is_avx2);
- if (isLoad2Unmasked) {
- for (int i = 0; i < 2; i++) {
- vmovss(Xmm(i + 1),
- ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovss(xmm0, ptr[BO2 + LDA * 1
- + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
- }
- vunpcklpd(xmm1, xmm1, xmm2);
- } else {
- vmovaps(xmm4, xmm3);
- vgatherqps(xmm1,
- ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
- xmm4);
- lea(BO2, ptr[BO2 + LDA * 4]);
- }
- vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
- xmm1);
-
- if (isLoad2Unmasked) {
- for (int i = 0; i < 2; i++) {
- vmovss(Xmm(i + 1),
- ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
- vmovss(xmm0, ptr[BO2 + LDA * 1
- + (0 * 8 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDA * 2]);
- vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
- }
- vunpcklpd(xmm1, xmm1, xmm2);
- } else {
- vextractf128(xmm4, ymm3, 1);
- vgatherqps(xmm1,
- ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
- xmm4);
- }
- vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
- xmm1);
- }
- add(BO1, SIZE);
- }
-
- add(AO1, unroll_m * SIZE);
- sub(LL, 1);
- jg(pack4, T_NEAR);
- align(16);
-
- L(pack10);
- };
-
- // Fused multiply add; may become one or two instructions
- auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
- bool overWrite = false) {
- if (useFma) {
- if (is_avx2) {
- vfmadd231ps(reg2, reg1, reg0);
- } else {
- assert(UNROLL_M == 8);
- auto tent_vreg = overWrite ? reg1 : ymm1;
- vmulps(tent_vreg, reg1, reg0);
- vaddps(reg2, reg2, tent_vreg);
- }
- } else {
- if (!overWrite) {
- vmulps(ymm15, reg1, reg0);
- vaddps(reg2, reg2, ymm15);
- } else {
- vmulps(reg1, reg1, reg0);
- vaddps(reg2, reg2, reg1);
- }
- }
- };
-
- // Inner kernel with k=8
- auto innerkernel8 = [&](int unroll_m, int unroll_n,
- bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
- bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
- Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
- Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
- Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
- Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
- Ymm reg23) {
-
- Ymm fmareg;
-
- if (!isDirect) {
- prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
-
- for (int i = 0; i < 8; i++) {
- if (isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg00 : reg12;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg06 : reg18;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- if (i == 0) {
- if (!isTransB) {
- prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
- }
- }
- if (unroll_n >= 2) {
- if (!isTransB) {
- if (i == 1) {
- prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg01 : reg13;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg07 : reg19;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (isCopy) {
- vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
- ymm0);
- if (unroll_m >= 16) {
- vmovups(ptr[LDA4
- + (unroll_m * i + 1 * 8 - OFFSET)
- * SIZE],
- ymm1);
- }
- if (i == 7) {
- sub(LDA4, -unroll_m * 8 * SIZE);
- }
- }
-
- if (unroll_n >= 3) {
- if (!isTransB) {
- if (i == 2) {
- prefetcht0(
- ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg02 : reg14;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg08 : reg20;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (i == 7) {
- if (!isTransB) {
- sub(BO1, -8 * SIZE);
- }
- }
-
- if (unroll_n >= 4) {
- if (!isTransB) {
- if (i == 3) {
- prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg03 : reg15;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg09 : reg21;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 5) {
- if (!isTransB) {
- if (i == 4) {
- prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg04 : reg16;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg10 : reg22;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 6) {
- if (!isTransB) {
- if (i == 5) {
- prefetcht0(
- ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg05 : reg17;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg11 : reg23;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
- if (isTransB) {
- prefetcht0(ptr[BO1 + BO2]);
- add(BO1, LDB);
- }
-
- if (i == 0) {
- if (unroll_m >= 4) {
- if (!isDirect) {
- prefetcht0(
- ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
- }
- }
- if (i == 1 || i == 2) {
- if (unroll_m >= 8) {
- if (!isDirect) {
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + (2 + 2 * i) * 8)
- * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
- }
- }
- if (i == 3 || i == 4 || i == 5 || i == 6) {
- if (unroll_m >= 16) {
- if (!isDirect) {
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + (2 + 2 * i) * 8)
- * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
- }
- }
- if (i == 7) {
- if (!isTransB) {
- if (unroll_n >= 4) {
- sub(BO2, -8 * SIZE);
- }
- }
- if (!isTransA) {
- prefetcht2(ptr[AA]);
- lea(AA, ptr[AA + LDA]);
- }
- }
-
- if (!isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0,
- ptr[AO1
- + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(
- ymm0, VMASK,
- ptr[AO1
- + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
- * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1
- + (unroll_m * (i + 1) + 1 * 8
- - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1
- + (unroll_m * (i + 1) + 1 * 8
- - OFFSET)
- * SIZE]);
- }
- }
- }
- }
-
- if (!isDirect) {
- sub(AO1, -unroll_m * 8 * SIZE);
- }
- sub(LL, 1);
-
- };
-
- // Inner kernel with k=4
- auto innerkernel4 = [&](int unroll_m, int unroll_n,
- bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
- bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
- Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
- Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
- Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
- Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
- Ymm reg23) {
-
- Ymm fmareg;
-
- if (!isDirect) {
- prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
-
- for (int i = 0; i < 4; i++) {
- if (isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg00 : reg12;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg06 : reg18;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- if (i == 0) {
- if (!isTransB) {
- prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
- }
- }
- if (unroll_n >= 2) {
- if (!isTransB) {
- if (i == 1) {
- prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg01 : reg13;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg07 : reg19;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (isCopy) {
- vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
- ymm0);
- if (unroll_m >= 16) {
- vmovups(ptr[LDA4
- + (unroll_m * i + 1 * 8 - OFFSET)
- * SIZE],
- ymm1);
- }
- if (i == 3) {
- sub(LDA4, -unroll_m * 4 * SIZE);
- }
- }
-
- if (unroll_n >= 3) {
- if (!isTransB) {
- if (i == 2) {
- prefetcht0(
- ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg02 : reg14;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg08 : reg20;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (i == 7) {
- if (!isTransB) {
- sub(BO1, -8 * SIZE);
- }
- }
-
- if (unroll_n >= 4) {
- if (!isTransB) {
- if (i == 3) {
- prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg03 : reg15;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg09 : reg21;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 5) {
- if (!isTransB) {
- if (i == 4) {
- prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg04 : reg16;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg10 : reg22;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 6) {
- if (!isTransB) {
- if (i == 5) {
- prefetcht0(
- ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg05 : reg17;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg11 : reg23;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
- if (isTransB) {
- prefetcht0(ptr[BO1 + BO2]);
- add(BO1, LDB);
- }
-
- if (i == 0) {
- if (unroll_m >= 4) {
- if (!isDirect) {
- prefetcht0(
- ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
- }
- }
- if (i == 1 || i == 2) {
- if (unroll_m >= 8) {
- if (!isDirect) {
- prefetcht0(ptr[AO1
- + (PREFETCHSIZEA + (2 + 2 * i) * 8)
- * SIZE]);
- } else {
- prefetcht0(ptr[AO1 + LDA4]);
- }
- }
- }
- if (i == 3) {
- if (!isTransB) {
- sub(BO1, -4 * SIZE);
- if (unroll_n >= 4) {
- sub(BO2, -4 * SIZE);
- }
- }
- }
-
- if (!isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0,
- ptr[AO1
- + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(
- ymm0, VMASK,
- ptr[AO1
- + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
- * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1
- + (unroll_m * (i + 1) + 1 * 8
- - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1
- + (unroll_m * (i + 1) + 1 * 8
- - OFFSET)
- * SIZE]);
- }
- }
- }
- }
-
- if (!isDirect) {
- sub(AO1, -unroll_m * 4 * SIZE);
- }
-
- };
-
- // Inner kernel with k=2
- auto innerkernel2 = [&](int unroll_m, int unroll_n,
- bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
- bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
- Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
- Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
- Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
- Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
- Ymm reg23) {
-
- Ymm fmareg;
-
- for (int i = 0; i < 2; i++) {
- if (isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg00 : reg12;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg06 : reg18;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- if (unroll_n >= 2) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg01 : reg13;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg07 : reg19;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 3) {
- if (!isTransB) {
- if (i == 2) {
- prefetcht0(
- ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
- }
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg02 : reg14;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg08 : reg20;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 4) {
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg03 : reg15;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg09 : reg21;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 5) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg04 : reg16;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg10 : reg22;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (unroll_n >= 6) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
- }
- fmareg = (i % 2 == 0) ? reg05 : reg17;
- fma(useFma, ymm0, ymm2, fmareg);
- if (unroll_m >= 16) {
- fmareg = (i % 2 == 0) ? reg11 : reg23;
- fma(useFma, ymm1, ymm2, fmareg);
- }
- }
-
- if (isCopy) {
- vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
- ymm0);
- if (unroll_m >= 16) {
- vmovups(ptr[LDA4
- + (unroll_m * 0 + 1 * 8 - OFFSET)
- * SIZE],
- ymm1);
- }
- sub(LDA4, -unroll_m * SIZE);
- }
-
- if (!isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0, ptr[AO1
- + (unroll_m * 1 + 0 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1
- + (unroll_m * 1 + 0 * 8 - OFFSET)
- * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1,
- ptr[AO1
- + (unroll_m * 1 + 1 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1
- + (unroll_m * 1 + 1 * 8 - OFFSET)
- * SIZE]);
- }
- }
- sub(AO1, -unroll_m * SIZE);
- }
-
- if (!isTransB) {
- sub(BO1, -SIZE);
- if (unroll_n >= 4) {
- sub(BO2, -SIZE);
- }
- } else {
- add(BO1, LDB);
- }
- }
-
- };
-
- // Inner kernel with k=1
- auto innerkernel1 = [&](int unroll_m, int unroll_n,
- bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
- bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
- Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
- Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
-
- if (isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
- }
- }
- add(AO1, LDA);
- }
-
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg00);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg06);
- }
-
- if (unroll_n >= 2) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg01);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg07);
- }
- }
-
- if (unroll_n >= 3) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg02);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg08);
- }
- }
-
- if (unroll_n >= 4) {
- if (!isTransB) {
- vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg03);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg09);
- }
- }
-
- if (unroll_n >= 5) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg04);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg10);
- }
- }
-
- if (unroll_n >= 6) {
- if (!isTransB) {
- vbroadcastss(
- ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
- } else {
- vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
- }
- fma(useFma, ymm0, ymm2, reg05);
- if (unroll_m >= 16) {
- fma(useFma, ymm1, ymm2, reg11);
- }
- }
-
- if (isCopy) {
- vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
- ymm0);
- if (unroll_m >= 16) {
- vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
- ymm1);
- }
- sub(LDA4, -unroll_m * SIZE);
- }
-
- if (!isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0,
- ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1
- + (unroll_m * 1 + 1 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1
- + (unroll_m * 1 + 1 * 8 - OFFSET)
- * SIZE]);
- }
- }
- sub(AO1, -unroll_m * SIZE);
- }
-
- if (!isTransB) {
- sub(BO1, -SIZE);
- if (unroll_n >= 4) {
- sub(BO2, -SIZE);
- }
- } else {
- add(BO1, LDB);
- }
-
- };
-
- // Main kernel; does prefetching and calls innerkernel{1,2,4,8} as
- // appropriate
- // After calculating results in registers, writes back to C matrix
- auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
- Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
- Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
- Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
- Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
- Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
- Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
- Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
- Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
- if (!isDirect) {
- lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
- } else {
- mov(AO1, A);
- }
-
- if (isCopy) {
- lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
- } else {
- lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
- }
-
- if (isTransB) {
- lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
- lea(BO2, ptr[BO2 + LDB * 2]);
- }
-
- if (!isDirect) {
- if (isLoad1Unmasked) {
- vmovups(ymm0,
- ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
- } else {
- vmaskmovps(ymm0, VMASK,
- ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
- }
- if (unroll_m >= 16) {
- if (isLoad2Unmasked) {
- vmovups(ymm1, ptr[AO1
- + (unroll_m * 0 + 1 * 8 - OFFSET)
- * SIZE]);
- } else {
- vmaskmovps(ymm1, VMASK,
- ptr[AO1
- + (unroll_m * 0 + 1 * 8 - OFFSET)
- * SIZE]);
- }
- }
- }
-
- for (int i = 4; i < 10; i++) {
- vxorps(Ymm(i), Ymm(i), Ymm(i));
- vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
- }
-
- mov(LL, K);
- sar(LL, 3);
-
- Label kernel12, kernel13, kernel14, kernel15;
- Label kernel16, kernel17, kernel18;
-
- sub(LL, SECOND_FETCH);
- jle(kernel13, T_NEAR);
- align(16);
-
- L(kernel12);
- innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
- reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
- reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
- reg21, reg22, reg23);
- jg(kernel12, T_NEAR);
- align(16);
-
- L(kernel13);
- prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
- if (unroll_n >= 2)
- prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
- if (unroll_n >= 3)
- prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
- if (unroll_n >= 4)
- prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
- if (unroll_n >= 5)
- prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
- if (unroll_n >= 6)
- prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
-
- add(LL, SECOND_FETCH);
- jle(kernel15, T_NEAR);
- align(16);
-
- L(kernel14);
- innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
- reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
- reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
- reg21, reg22, reg23);
- jg(kernel14, T_NEAR);
- align(16);
-
- L(kernel15);
- test(K, 4);
- jle(kernel16, T_NEAR);
- innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
- reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
- reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
- reg21, reg22, reg23);
-
- L(kernel16);
- test(K, 2);
- jle(kernel17, T_NEAR);
- innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
- reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
- reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
- reg21, reg22, reg23);
- align(16);
-
- L(kernel17);
- if (unroll_m == 16) {
- if (unroll_n <= 3) {
- vaddps(reg00, reg00, reg12);
- vaddps(reg01, reg01, reg13);
- vaddps(reg02, reg02, reg14);
- vaddps(reg06, reg06, reg18);
- vaddps(reg07, reg07, reg19);
- vaddps(reg08, reg08, reg20);
- }
- }
-
- if (unroll_m <= 8) {
- vaddps(reg00, reg00, reg12);
- vaddps(reg01, reg01, reg13);
- vaddps(reg02, reg02, reg14);
- vaddps(reg03, reg03, reg15);
- vaddps(reg04, reg04, reg16);
- vaddps(reg05, reg05, reg17);
- }
-
- test(K, 1);
- jle(kernel18, T_NEAR);
- innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
- reg05, reg06, reg07, reg08, reg09, reg10, reg11);
- align(16);
-
- L(kernel18);
- vbroadcastss(VALPHA, ALPHA);
-
- if (isBetaN) {
- vbroadcastss(VBETA, BETA);
- }
-
- // Write back the results; all beta and bias cases need to be
- // handled
- switch (unroll_n) {
- case 1: mov(rax, LDC); break;
- case 2: lea(rax, ptr[LDC * 2]); break;
- case 3: lea(rax, ptr[LDC + LDC * 2]); break;
- case 4: lea(rax, ptr[LDC + LDC * 4]); break;
- case 5:
- lea(rax, ptr[LDC * 4]);
- add(rax, LDC);
- break;
- case 6:
- lea(rax, ptr[LDC + LDC * 2]);
- add(rax, rax);
- break;
- }
-
- if (hasBias) {
- mov(BIAS1, BIAS);
- if (isLoad1Unmasked) {
- vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
- } else {
- vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
- }
- }
-
- for (int i = 0; i < unroll_n; i++) {
- vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
- if (!isBeta0) {
- if (isLoad1Unmasked) {
- switch (i) {
- case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
- case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
- case 2:
- vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
- break;
- case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
- case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
- case 5:
- vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
- break;
- }
- } else {
- switch (i) {
- case 0:
- vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
- break;
- case 1:
- vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
- break;
- case 2:
- vmaskmovps(
- ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
- break;
- case 3:
- vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
- break;
- case 4:
- vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
- break;
- case 5:
- vmaskmovps(
- ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
- break;
- }
- }
-
- if (!isBetaN) {
- vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
- } else {
- fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
- }
- }
- if (hasBias) {
- vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
- }
- if (isLoad1Unmasked) {
- switch (i) {
- case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
- case 1:
- vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
- break;
- case 2:
- vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
- break;
- case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
- case 4:
- vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
- break;
- case 5:
- vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
- break;
- }
- } else {
- switch (i) {
- case 0:
- vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
- break;
- case 1:
- vmaskmovps(
- ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
- break;
- case 2:
- vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
- Ymm(i + 4));
- break;
- case 3:
- vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
- break;
- case 4:
- vmaskmovps(
- ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
- break;
- case 5:
- vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
- Ymm(i + 4));
- break;
- }
- }
-
- if (unroll_m >= 16) {
- // Re-use ymm4 (VBIAS2)
- if (i == 0) {
- if (hasBias) {
- if (isLoad1Unmasked) {
- vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
- } else {
- vmaskmovps(
- VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
- }
- }
- }
- vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
- if (!isBeta0) {
- if (isLoad2Unmasked) {
- switch (i) {
- case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
- case 1:
- vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
- break;
- case 2:
- vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
- break;
- case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
- case 4:
- vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
- break;
- case 5:
- vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
- break;
- }
- } else {
- switch (i) {
- case 0:
- vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
- break;
- case 1:
- vmaskmovps(
- ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
- break;
- case 2:
- vmaskmovps(ymm0, VMASK,
- ptr[CO1 + LDC * 2 + 8 * SIZE]);
- break;
- case 3:
- vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
- break;
- case 4:
- vmaskmovps(
- ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
- break;
- case 5:
- vmaskmovps(ymm0, VMASK,
- ptr[CO2 + LDC * 2 + 8 * SIZE]);
- break;
- }
- }
- if (!isBetaN) {
- vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
- } else {
- fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
- }
- }
- if (hasBias) {
- vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
- }
- if (isLoad2Unmasked) {
- switch (i) {
- case 0:
- vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
- break;
- case 1:
- vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
- break;
- case 2:
- vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
- break;
- case 3:
- vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
- break;
- case 4:
- vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
- break;
- case 5:
- vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
- break;
- }
- } else {
- switch (i) {
- case 0:
- vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
- break;
- case 1:
- vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
- Ymm(i + 10));
- break;
- case 2:
- vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
- Ymm(i + 10));
- break;
- case 3:
- vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
- break;
- case 4:
- vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
- Ymm(i + 10));
- break;
- case 5:
- vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
- Ymm(i + 10));
- break;
- }
- }
- }
- if (i == 2)
- add(CO1, rax);
- }
- if (unroll_n >= 4) {
- add(CO2, rax);
- }
-
- // Compute next address of B
- if (!isTransB) {
- lea(rax, ptr[K * SIZE]);
- switch (unroll_n) {
- case 1:
- add(BO1, LDB);
- add(BO2, LDB);
- break;
- case 2:
- lea(BO1, ptr[BO1 + LDB * 2]);
- lea(BO2, ptr[BO2 + LDB * 2]);
- break;
- case 3:
- lea(BO1, ptr[BO1 + LDB3]);
- lea(BO2, ptr[BO2 + LDB3]);
- break;
- case 4:
- lea(BO1, ptr[BO1 + LDB * 4]);
- lea(BO2, ptr[BO2 + LDB * 4]);
- break;
- case 5:
- lea(BO1, ptr[BO1 + LDB * 4]);
- add(BO1, LDB);
- lea(BO2, ptr[BO2 + LDB * 4]);
- add(BO2, LDB);
- break;
- case 6:
- lea(BO1, ptr[BO1 + LDB3 * 2]);
- lea(BO2, ptr[BO2 + LDB3 * 2]);
- break;
- }
- sub(BO1, rax);
- sub(BO2, rax);
- } else {
- mov(rax, LDB);
- imul(rax, K);
- sub(BO1, rax);
- add(BO1, unroll_n * SIZE);
- }
- };
-
- auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, true);
- };
-
- auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, true);
- };
-
- auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, true);
- };
-
- auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy,
- bool useFma = true) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
- Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
- Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
- Ymm(13), Ymm(14), Ymm(15));
- };
-
- auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, false);
- };
-
- auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, false);
- };
-
- auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy,
- bool useFma = true) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
- Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
- Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
- Ymm(15));
- };
-
- auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy);
- };
-
- auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy);
- };
-
- auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy,
- bool useFma = true) {
- kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
- Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
- Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
- Ymm(13), Ymm(14), Ymm(15));
- };
-
- auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, false);
- };
-
- auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
- bool isLoad2Unmasked, bool isDirect, bool isCopy) {
- kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
- isDirect, isCopy, false);
- };
-
- // High-level subroutine; does packing if needed, then splits C matrix.
- // Operates on chunks of 16 rows, 6 columns at a time (handling tail
- // cases appropriately).
- // Masking is used for tail cases where M is not divisible by 8.
- auto subloop = [&](
- int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
- if (isTransA) {
- do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
- }
-
- Label subloop11, subloop11mask;
- Label subloop20, subloop21, subloop22, subloop23;
- Label subloop24, subloop25;
- Label subloop30, subloop31, subloop32, subloop33;
- Label subloop34, subloop35;
- Label subloop98, subloop98mask;
- Label subloop99, subloop99mask;
-
- mov(CO1, C);
- lea(CO2, ptr[CO1 + LDC * 2]);
- add(CO2, LDC);
- add(C, unroll_m * SIZE);
- mov(BO1, B);
- if (!isTransB) {
- lea(BO2, qword[B + LDB3]);
- }
-
- if (!isTransA) {
- lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
- cmp(M, UNROLL_M);
- jg(subloop98, T_NEAR);
-
- mov(AA, ORIG_A);
- lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
- L(subloop98);
- }
-
- mov(LL, N);
- mov(I, LL);
- if (!isTransA) {
- // If N is too small, skip copy operation
- cmp(LL, UNROLL_N * 3);
- jle(subloop30, T_NEAR);
-
-            // If A is not cache-line aligned, fall through to the copy path
- cmp(FLAG, 0);
- je(subloop30, T_NEAR);
- } else {
- cmp(LL, UNROLL_N);
- jl(subloop20, T_NEAR);
- }
- align(16);
-
- if (!isTransA) {
- if (unroll_m == 16) {
- kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, true, true);
- } else {
- kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, true, true);
- }
- } else {
- if (unroll_m == 16) {
- kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, false, false);
- } else {
- kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, false, false);
- }
- }
-
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jl(subloop20, T_NEAR);
- align(16);
-
- L(subloop11);
- if (unroll_m == 16) {
- kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, false, false);
- } else {
- kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- }
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop11, T_NEAR);
- align(16);
-
- L(subloop20);
- cmp(I, 1);
- jne(subloop21, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- } else {
- kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
- false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop21);
- cmp(I, 2);
- jne(subloop22, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- } else {
- kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
- false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop22);
- cmp(I, 3);
- jne(subloop23, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- } else {
- kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
- false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop23);
- cmp(I, 4);
- jne(subloop24, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- } else {
- kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
- false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop24);
- cmp(I, 5);
- jne(subloop99, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
- false, false);
- } else {
- kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
- false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- if (!isTransA) {
- L(subloop30);
- cmp(I, UNROLL_N);
- jl(subloop25, T_NEAR);
- align(16);
-
- L(subloop31);
- if (unroll_m == 16) {
- kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, true, false);
- } else {
- kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
- isLoad2Unmasked, true, false);
- }
- sub(I, UNROLL_N);
- cmp(I, UNROLL_N);
- jge(subloop31, T_NEAR);
- align(16);
-
- L(subloop25);
- cmp(I, 1);
- jne(subloop32, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- } else {
- kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop32);
- cmp(I, 2);
- jne(subloop33, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- } else {
- kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop33);
- cmp(I, 3);
- jne(subloop34, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- } else {
- kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop34);
- cmp(I, 4);
- jne(subloop35, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- } else {
- kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- }
- jmp(subloop99, T_NEAR);
- align(16);
-
- L(subloop35);
- cmp(I, 5);
- jne(subloop99, T_NEAR);
- if (unroll_m == 16) {
- kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- } else {
- kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
- true, false);
- }
- align(16);
- }
-
- L(subloop99);
- // Compute address for A
- if (!isTransA) {
- add(A, unroll_m * SIZE);
- } else {
- mov(rax, LDA);
- imul(rax, rax, unroll_m);
- add(A, rax);
- }
-
- // Compute next address of BIAS
- if (hasBias) {
- add(BIAS, unroll_m * SIZE);
- }
- };
-
- preamble();
-
- Label buffer_in_ws, buffer_allocated;
-
- // Get the registers
- mov(B, ARG_B);
- mov(LDB, ARG_LDB);
- mov(r15, ARG_BETA);
- mov(r12, ARG_C);
- if (hasBias)
- mov(r10, ARG_BIAS);
- mov(LDC, ARG_LDC);
- mov(rbp, rsp);
-
- vmovss(xmm0, ptr[ARG_ALPHA]);
- vmovss(xmm1, ptr[r15]);
-
-#if _WIN32
- mov(A, ARG_A);
- mov(LDA, ARG_LDA);
-#endif
-
- cmp(K, STACK_K_CAPACITY);
- jg(buffer_in_ws, T_NEAR);
-
- // Create buffer and align to 4kB page
- lea(rax, ptr[K * SIZE]);
- sal(rax, 4);
- add(rax, 256);
- sub(rsp, rax);
- and_(rsp, -PAGE_4K);
- jmp(buffer_allocated, T_NEAR);
-
- L(buffer_in_ws);
- mov(rsp, ARG_WS);
-
- L(buffer_allocated);
-
- mov(ORIG_SP, rbp);
- mov(M, ARG_M);
- mov(N, ARG_N);
- mov(C, r12);
- if (hasBias)
- mov(BIAS, r10);
- vmovss(ALPHA, xmm0);
- vmovss(BETA, xmm1);
- sub(A, -OFFSET * SIZE);
- sub(B, -OFFSET * SIZE);
- mov(ORIG_A, A);
- sal(LDA, BASE_SHIFT);
- sal(LDB, BASE_SHIFT);
- sal(LDC, BASE_SHIFT);
- lea(LDB3, ptr[LDB + LDB * 2]);
-
- for (int i = 0; i < 8; i++) {
- mov(dword[rsp + 88 + i * 4], i);
- }
-
- if (isTransA && is_avx2) {
- movq(xmm0, LDA);
- vpbroadcastq(ymm1, xmm0);
- vinsertf128(ymm0, ymm0, xmm0, 1);
- vpermilpd(ymm0, ymm0, 5);
- vpaddq(ymm1, ymm1, ymm1);
- vperm2f128(ymm1, ymm1, ymm1, 8);
- vpaddq(ymm0, ymm0, ymm1);
- vmovups(STRIDE, ymm0);
- }
-
- // Check A alignment and leading dimension; take copy-based path as
- // needed
- mov(rax, LDA);
- or_(rax, A);
- and_(rax, 0x1f);
- mov(FLAG, rax);
-
- Label main0, main1, main2, main3, main999;
-
- cmp(M, UNROLL_M);
- jl(main0, T_NEAR);
- align(16);
-
- L(main1);
- subloop(UNROLL_M, true, true);
- sub(M, UNROLL_M);
- cmp(M, UNROLL_M);
- jge(main1, T_NEAR);
- align(16);
-
- L(main0);
- cmp(M, 0);
- jle(main999, T_NEAR);
-
- if (UNROLL_M > 8) {
- cmp(M, 8);
- jle(main2, T_NEAR);
-
- sub(M, 8);
- vbroadcastss(VMASK, M);
- vpcmpgtd(VMASK, VMASK, MASK);
-
- subloop(16, true, false);
- jmp(main999, T_NEAR);
- align(16);
-
- L(main2);
- cmp(M, 8);
- jne(main3, T_NEAR);
- subloop(8, true, true);
- jmp(main999, T_NEAR);
- }
-
- align(16);
-
- L(main3);
- vbroadcastss(VMASK, M);
- if (is_avx2) {
- vpcmpgtd(VMASK, VMASK, MASK);
- } else {
- auto xmask = Xmm(VMASK.getIdx());
- auto xmm_tmp = xmm4;
-
- vextractf128(xmm_tmp, VMASK, 1);
- vpcmpgtd(xmask, xmask, MASK);
- vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
- vinsertf128(VMASK, VMASK, xmm_tmp, 1);
- }
- subloop(8, false, false);
- align(16);
-
- L(main999);
- // Restore original stack
- mov(rsp, ORIG_SP);
-
- vzeroupper();
- postamble();
-
- ker_ = this->getCode<ker_t>();
- }
-
- typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
- const float *alpha, const float *a, dim_t lda,
- const float *b, dim_t ldb, const float *beta, float *c,
- dim_t ldc, const float *bias, float *ws);
-
- void operator()(dim_t m, dim_t n, dim_t k,
- const float *alpha, const float *a, dim_t lda,
- const float *b, dim_t ldb, const float *beta, float *c,
- dim_t ldc, const float *bias, float *ws) const
- {
- ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
- }
-
-private:
- ker_t ker_;
-};
-
-const xbyak_gemm *get_xbyak_gemm(
- bool isTransA, bool isTransB, float beta, bool hasBias) {
- auto beta_idx = [](float beta) {
- return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
- };
-
- // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
- static xbyak_gemm *kernel_table[2][2][2][3];
- static std::once_flag initialized;
- std::call_once(initialized, [=]{
- for (bool isTransA: {false, true})
- for (bool isTransB: {false, true})
- for (bool hasBias: {false, true})
- for (float beta: {0.0f, 1.0f, 2.0f}) {
- // nocopy sgemm with bias for beta != 0.0 is not supported
- if (hasBias && beta != 0.0)
- continue;
- kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
- new xbyak_gemm(isTransA, isTransB, beta, hasBias);
- }
- });
-
- return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
-}
-
-void sgemm_nocopy_driver(const char *transa,
- const char *transb, int m, int n, int k, const float *alpha,
- const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
- float *c, dim_t ldc, const float *bias, float *ws)
-{
- bool isTransA = (*transa == 'T' || *transa == 't');
- bool isTransB = (*transb == 'T' || *transb == 't');
-
- int Bm, sizeM, Bn, sizeN, Bk, sizeK;
-
- int i, j;
-
- if ((m <= 0) || (n <= 0))
- return;
-
- if ((k <= 0) || (alpha[0] == 0.)) {
-
- if (beta[0] == 0.) {
- for (j = 0; j < n; j++)
- for (i = 0; i < m; i++)
- c[i + j * ldc] = 0.0;
- } else if (beta[0] != 1.) {
- for (j = 0; j < n; j++)
- for (i = 0; i < m; i++)
- c[i + j * ldc] *= beta[0];
- }
-
- return;
- }
-
- assert(IMPLICATION(bias != nullptr, *beta == 0.0));
-
- // XXX: this happens on every thread...
- bool hasBias = (bias != nullptr);
- auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
- auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
- auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
- assert(ker_bn && ker_b1 && ker_b0);
-
- int BM = 4032;
- int BN = isTransA ? 96 : 48;
- int BK = isTransB ? 96 : 256;
- const float *curA, *curB, *curBias = nullptr;
- float *curC;
-
- for (Bk = 0; Bk < k; Bk += sizeK) {
- sizeK = k - Bk;
- if (sizeK >= BK * 2)
- sizeK = BK;
- else {
- if (sizeK > BK)
- sizeK = (sizeK + 1) / 2;
- }
-
- for (Bm = 0; Bm < m; Bm += sizeM) {
- sizeM = m - Bm;
- if (sizeM >= BM * 2)
- sizeM = BM;
- else {
- if (sizeM > BM + BM / 2)
- sizeM = (sizeM + 1) / 2;
- }
-
- for (Bn = 0; Bn < n; Bn += sizeN) {
- sizeN = n - Bn;
- if (sizeN >= BN * 2)
- sizeN = BN;
- else {
- if (sizeN > BN + BN / 2)
- sizeN = (sizeN + 1) / 2;
- }
-
- if (!isTransA) {
- curA = a + Bm + Bk * lda;
- } else {
- curA = a + Bk + Bm * lda;
- }
- if (!isTransB) {
- curB = b + Bk + Bn * ldb;
- } else {
- curB = b + Bn + Bk * ldb;
- }
- curC = c + Bm + (size_t)Bn * ldc;
- if (bias != nullptr) {
- if (Bk == 0) {
- curBias = bias + Bm;
- } else {
- curBias = nullptr;
- }
- }
- if (Bk == 0) {
- if (*beta == 0.0 && bias == nullptr)
- (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- else
- (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- } else {
- (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
- alpha, curA, lda, curB, ldb, beta, curC, ldc,
- curBias, ws);
- }
- }
- }
- }
-}
-
-}
-
-mkldnn_status_t jit_avx_gemm_f32(
- const char *transa, const char *transb,
- const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
- const float *A, const int *p_lda, const float *B, const int *p_ldb,
- const float *p_beta, float *C, const int *p_ldc, const float *bias)
-{
- using namespace mkldnn::impl::utils;
- using namespace avx_gemm_f32;
- using namespace gemm_utils;
-
- if (*p_beta != 0 && bias)
- return ref_gemm(transa, transb, p_m, p_n, p_k,
-            p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
-
- int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
-
- int m = *p_m;
- int n = *p_n;
- int k = *p_k;
- dim_t lda = *p_lda;
- dim_t ldb = *p_ldb;
- dim_t ldc = *p_ldc;
- float beta = *p_beta;
- int MB, NB, KB;
-
- int nthr_m, nthr_n, nthr_k, nthr_mn;
-
- // Determine threading partitioning
- calc_nthr_nocopy_avx(
- m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
-
- // May not happen, but just in case
- if (nthr < nthr_m * nthr_n * nthr_k)
- nthr = nthr_m * nthr_n * nthr_k;
-
- nthr_mn = nthr_m * nthr_n;
-
- unsigned char * ompstatus_ = nullptr;
- unsigned char volatile *ompstatus = nullptr;
-
- float *c_buffers = nullptr;
- float *ws_buffers = nullptr;
-
- if (nthr_k > 1) {
- ompstatus_ = (unsigned char *) malloc(
- nthr * CACHE_LINE_SIZE,
- CACHE_LINE_SIZE);
- ompstatus = (unsigned char volatile *) ompstatus_;
- assert(ompstatus);
-
- for (int i = 0; i < nthr; i++)
- ompstatus[i * CACHE_LINE_SIZE] = 0;
-
- c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
- * sizeof(float), PAGE_4K);
- }
-
- const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
- const size_t ws_size_per_thr
- = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
- if (k > STACK_K_CAPACITY) {
- ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
- }
-
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_m, ithr_n, ithr_k, ithr_mn;
- int m_from, m_to, myM;
- int n_from, n_to, myN;
- int k_from, k_to, myK;
- int cbase, ibase;
- const float *myA, *myB, *myBias = nullptr;
- float *myC = C, myBeta;
- float *ws = ws_buffers ?
- ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
- dim_t ld = ldc;
-
- int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
-
- if (ithr < nthr_m * nthr_n * nthr_k) {
-
- ithr_mn = ithr % nthr_mn;
- ithr_m = ithr_mn % nthr_m;
- ithr_n = ithr_mn / nthr_m;
- ithr_k = ithr / nthr_mn;
-
- /* swap ithr_k for performance improvement */
- if (ithr_k == 0)
- ithr_k = nthr_k - 1;
- else if (ithr_k == nthr_k - 1)
- ithr_k = 0;
-
- m_from = MB * (ithr_m);
- m_to = MB * (ithr_m + 1);
- if (m_to > m)
- m_to = m;
- myM = m_to - m_from;
-
- n_from = NB * (ithr_n);
- n_to = NB * (ithr_n + 1);
- if (n_to > n)
- n_to = n;
- myN = n_to - n_from;
-
- k_from = KB * (ithr_k);
- k_to = KB * (ithr_k + 1);
- if (k_to > k)
- k_to = k;
- myK = k_to - k_from;
-
- cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
- ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
-
- if ((myM > 0) && (myN > 0)) {
-
- if (*transa == 'N' || *transa == 'n') {
- myA = &(A[m_from + k_from * lda]);
- } else {
- myA = &(A[k_from + m_from * lda]);
- }
- if (*transb == 'N' || *transb == 'n') {
- myB = &(B[k_from + n_from * ldb]);
- } else {
- myB = &(B[n_from + k_from * ldb]);
- }
- if (ithr_k == 0) {
- myC = &(C[m_from + n_from * ldc]);
- myBeta = beta;
- ld = ldc;
- if (bias)
- myBias = &(bias[m_from]);
- } else {
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
- myBeta = 0.0;
- ld = MB;
- myBias = nullptr;
- }
-
- sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
- lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
-
- if (nthr_k > 1 && !sum_later)
- ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
- }
-
- if (nthr_k > 1 && !sum_later) {
-
- // sum matrices partitioned along K dimension
- int n1, n2;
-
- partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
-
- if (ithr_k > 0) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
- + (dim_t)n1 * MB;
- /* need to wait until main thread finishes */
- while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
- };
-
- /* my cache is hot */
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
-
- for (int ik = 1; ik < nthr_k; ++ik) {
- if (ik != ithr_k) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
- + (dim_t)n1 * MB;
-
- while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
- };
-
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
- }
- }
- }
- });
-
- // handle C summation later
- if (nthr_k > 1 && ompstatus[0] == 0) {
-
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_m, ithr_n, ithr_k, ithr_mn;
- int m_from, m_to, myM;
- int n_from, n_to, myN;
- int cbase;
- float *myC = C;
-
- if (ithr < nthr_m * nthr_n * nthr_k) {
-
- ithr_mn = ithr % nthr_mn;
- ithr_m = ithr_mn % nthr_m;
- ithr_n = ithr_mn / nthr_m;
- ithr_k = ithr / nthr_mn;
-
- /* swap ithr_k for performance improvement */
- if (ithr_k == 0)
- ithr_k = nthr_k - 1;
- else if (ithr_k == nthr_k - 1)
- ithr_k = 0;
-
- m_from = MB * (ithr_m);
- m_to = MB * (ithr_m + 1);
- if (m_to > m)
- m_to = m;
- myM = m_to - m_from;
-
- n_from = NB * (ithr_n);
- n_to = NB * (ithr_n + 1);
- if (n_to > n)
- n_to = n;
- myN = n_to - n_from;
-
- cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
-
- if (nthr_k > 1) {
- // sum matrices partitioned along K dimension
- int n1, n2;
-
- partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
-
- if (ithr_k > 0) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
- + (dim_t)n1 * MB;
-
- /* my cache is hot */
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
-
- for (int ik = 1; ik < nthr_k; ++ik) {
- if (ik != ithr_k) {
-
- myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
- + (dim_t)n1 * MB;
-
- sum_two_matrices(myM, n2, myC, MB,
- &C[m_from + (n_from + n1) * ldc], ldc);
- }
- }
- }
- }
- });
- }
-
-
- free(c_buffers);
- free(ompstatus_);
- free(ws_buffers);
-
- return mkldnn_success;
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
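
For reference, the sgemm_nocopy_driver removed above tiles the multiplication into cache blocks (BM x BK panels of A against BK x BN panels of B) and, for each K-block, dispatches to one of three JIT kernels: ker_b0 (beta == 0) or ker_bn (general beta, optional bias) on the first K-block, then ker_b1 (beta == 1) on later K-blocks so partial products accumulate into C. A minimal scalar sketch of that blocking and dispatch logic follows; it is illustrative only (column-major, no transposes, no bias, naive loops standing in for the JIT kernels) and is not the removed implementation.

    #include <algorithm>
    #include <cstddef>

    // Scalar stand-in for the blocked driver deleted above (sketch only).
    static void blocked_sgemm_sketch(int m, int n, int k, float alpha, float beta,
            const float *A, std::ptrdiff_t lda, const float *B, std::ptrdiff_t ldb,
            float *C, std::ptrdiff_t ldc) {
        const int BM = 4032, BN = 48, BK = 256; // cache-block sizes used by the driver
        for (int bk = 0; bk < k; bk += BK) {
            const int kb = std::min(k - bk, BK);
            // First K-block applies the caller's beta (ker_b0 / ker_bn);
            // later K-blocks accumulate on top of C (ker_b1, beta == 1).
            const float eff_beta = (bk == 0) ? beta : 1.0f;
            for (int bm = 0; bm < m; bm += BM) {
                const int mb = std::min(m - bm, BM);
                for (int bn = 0; bn < n; bn += BN) {
                    const int nb = std::min(n - bn, BN);
                    // Naive kernel in place of the JIT-generated 16x6 kernel.
                    for (int j = 0; j < nb; ++j)
                        for (int i = 0; i < mb; ++i) {
                            float acc = 0.0f;
                            for (int p = 0; p < kb; ++p)
                                acc += A[(bm + i) + (std::ptrdiff_t)(bk + p) * lda]
                                        * B[(bk + p) + (std::ptrdiff_t)(bn + j) * ldb];
                            float &c = C[(bm + i) + (std::ptrdiff_t)(bn + j) * ldc];
                            c = (eff_beta == 0.0f) ? alpha * acc
                                                   : alpha * acc + eff_beta * c;
                        }
                }
            }
        }
    }
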
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
deleted file mode 100644
index aabf520a3c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
+++ /dev/null
@@ -1,37 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX_GEMM_F32_HPP
-#define JIT_AVX_GEMM_F32_HPP
-
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t jit_avx_gemm_f32(
- const char *transa, const char *transb, const int *M,
- const int *N, const int *K, const float *alpha, const float *A,
- const int *lda, const float *B, const int *ldb, const float *beta,
- float *C, const int *ldc, const float *bias = nullptr);
-
-
-}
-}
-}
-
-#endif
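
For reference, the header removed above declares a single BLAS-style entry point: every scalar argument is passed by pointer, matrices are column-major, and an optional bias vector may be added to C. A hypothetical call, assuming the header were still present, might look like the following (matrix sizes and the example function name are made up for illustration):

    #include <vector>

    #include "jit_avx_gemm_f32.hpp" // the header deleted above

    void example_sgemm() {
        const int M = 64, N = 32, K = 128;
        const int lda = M, ldb = K, ldc = M; // column-major leading dimensions
        const float alpha = 1.0f, beta = 0.0f;
        std::vector<float> A(lda * K), B(ldb * N), C(ldc * N);
        // C = alpha * A * B + beta * C; "N" means the operand is not transposed,
        // and the trailing bias pointer defaults to nullptr.
        mkldnn::impl::cpu::jit_avx_gemm_f32("N", "N", &M, &N, &K, &alpha,
                A.data(), &lda, B.data(), &ldb, &beta, C.data(), &ldc);
    }
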
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
deleted file mode 100644
index 5147885a89..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
+++ /dev/null
@@ -1,346 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-#include "gemm_utils_f32.hpp"
-#include "ref_gemm_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace gemm_utils;
-
-namespace {
-
-template <typename data_t>
-void copy_A(
- bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) {
- for (int k = 0; k < K; k++) {
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_factor<data_t>::m; i++) {
- ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
- }
- ws += unroll_factor<data_t>::m;
- }
-}
-
-template <typename data_t, bool isTransA, bool isTransB>
-void kernel_mxn(int K, const data_t *A, const dim_t lda,
- const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc,
- const data_t alpha, const data_t beta) {
- data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
- { static_cast<data_t>(0.) };
- for (int k = 0; k < K; k++) {
- for (int j = 0; j < unroll_factor<data_t>::n; j++) {
- data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_factor<data_t>::m; i++) {
- data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
- c[i + unroll_factor<data_t>::m * j] += a * b;
- }
- }
- }
- for (int j = 0; j < unroll_factor<data_t>::n; j++) {
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < unroll_factor<data_t>::m; i++) {
- C[i + j * ldc] = (beta == static_cast<data_t>(0.))
- ? alpha * c[i + unroll_factor<data_t>::m * j]
- : alpha * c[i + unroll_factor<data_t>::m * j]
- + beta * C[i + j * ldc];
- }
- }
-}
-
-template <typename data_t, bool isTransA, bool isTransB>
-void block_ker(const int M, const int N, const int K,
- const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
- data_t *C, const dim_t ldc, const data_t alpha, const data_t beta,
- data_t *ws, bool do_copy) {
- int Nu = rnd_dn(N, unroll_factor<data_t>::n);
- int Mu = rnd_dn(M, unroll_factor<data_t>::m);
- for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
- for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
- const data_t *b = isTransB ? &B[j] : &B[j * ldb];
- const data_t *a = isTransA ? &A[i * lda] : &A[i];
- if (do_copy) {
- if (j == 0) {
- copy_A<data_t>(isTransA, K, a, lda, ws);
- }
- kernel_mxn<data_t, false, isTransB>(
- K, ws, unroll_factor<data_t>::m, b, ldb,
- &C[i + j * ldc], ldc, alpha, beta);
- } else {
- kernel_mxn<data_t, isTransA, isTransB>(
- K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
- }
- }
- }
- // tail processing
- for (int i = 0; i < M; i++) {
- for (int j = Nu; j < N; j++) {
- data_t c = beta == static_cast<data_t>(0.)
- ? static_cast<data_t>(0.)
- : beta * C[i + j * ldc];
- for (int p = 0; p < K; p++) {
- data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
- data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
- c += alpha * a * b;
- }
- C[i + j * ldc] = c;
- }
- }
- for (int i = Mu; i < M; i++) {
- for (int j = 0; j < Nu; j++) {
- data_t c = beta == static_cast<data_t>(0.)
- ? static_cast<data_t>(0.)
- : beta * C[i + j * ldc];
- for (int p = 0; p < K; p++) {
- data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
- data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
- c += alpha * a * b;
- }
- C[i + j * ldc] = c;
- }
- }
-}
-
-template <typename data_t, bool isTransA, bool isTransB>
-void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
- const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
- const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
- data_t *ws) {
- constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
- constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
- constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
-
- const data_t *curA;
- const data_t *curB;
- data_t *curC;
-
- if ((M <= 0) || (N <= 0))
- return;
-
- if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
- dim_t MN = N * M;
- if (beta == static_cast<data_t>(0.)) {
- for (dim_t j = 0; j < MN; j++)
- C[j] = static_cast<data_t>(0.);
- } else if (beta != static_cast<data_t>(1.)) {
- for (dim_t j = 0; j < MN; j++)
- C[j] *= beta;
- }
- return;
- }
-
- for (int Bk = 0; Bk < K; Bk += BK) {
- int kb = nstl::min(K - Bk, BK);
- for (int Bm = 0; Bm < M; Bm += BM) {
- int mb = nstl::min(M - Bm, BM);
- for (int Bn = 0; Bn < N; Bn += BN) {
- int nb = nstl::min(N - Bn, BN);
- curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
- curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
- curC = C + Bm + Bn * ldc;
- if (Bk == 0) {
- block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
- curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
- } else {
- block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
- curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
- ws, do_copy);
- }
- }
- }
- }
-}
-
-}
-
-template <typename data_t>
-mkldnn_status_t ref_gemm(
- const char *transa_, const char *transb_, const int *M_,
- const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
- const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
- data_t *C, const int *ldc_, const data_t *bias) {
-
- bool isTransA = (*transa_ == 'T' || *transa_ == 't');
- bool isTransB = (*transb_ == 'T' || *transb_ == 't');
- const int M = *M_, N = *N_, K = *K_;
- const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
- const data_t alpha = *alpha_, beta = *beta_;
-
- int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
- int nthr_m, nthr_n, nthr_k;
- int MB, NB, KB;
- // thread balancing over M, N, K & size of blocking dimensions
- calc_nthr_nocopy_avx(
- M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
-
- data_t *c_buffers = nullptr;
- data_t *ws_buffers = nullptr;
- if (nthr_k > 1) {
- c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
- * sizeof(data_t), PAGE_4K);
- if (!c_buffers) {
- nthr_k = 1;
- KB = K;
- }
- }
-
- bool do_copy = (NB / unroll_factor<data_t>::n > 3);
- const int nthr_mn = nthr_m * nthr_n;
- const int nthr = nthr_mn * nthr_k;
- const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
- const size_t ws_size_per_thr
- = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
- if (do_copy) {
- ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
- if (!ws_buffers)
- do_copy = false;
- }
-
- auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
- int ithr) {
- from = NB * (ithr);
- to = NB * (ithr + 1);
- if (to > N)
- to = N;
- myN = to - from;
- };
-
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_mn = ithr % nthr_mn;
- int ithr_m = ithr_mn % nthr_m;
- int ithr_n = ithr_mn / nthr_m;
- int ithr_k = ithr / nthr_mn;
-
- int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
-
- data_t *ws = do_copy
- ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
- : nullptr;
-
- int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
- k_from = 0, k_to = 0, myK = 0;
-
- get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
- get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
- get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
-
- if (myM > 0 && myN > 0) {
- data_t myBeta, *myC;
- dim_t ld;
- if (ithr_k == 0) {
- myC = &(C[m_from + n_from * ldc]);
- myBeta = beta;
- ld = ldc;
- } else {
- myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
- myBeta = 0.0f;
- ld = MB;
- }
- const data_t *myA = isTransA
- ? &(A[k_from + m_from * lda])
- : &(A[m_from + k_from * lda]);
- const data_t *myB = isTransB
- ? &(B[n_from + k_from * ldb])
- : &(B[k_from + n_from * ldb]);
-
- if (!isTransA) {
- if (!isTransB) {
- gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
- lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
- } else {
- gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
- lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
- }
- } else {
- if (!isTransB) {
- gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
- lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
- } else {
- gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
- lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
- }
- }
- }
- });
-
- if (nthr_k > 1) {
- parallel_nd(nthr, [&](const int ithr) {
- int ithr_mn = ithr % nthr_mn;
- int ithr_m = ithr_mn % nthr_m;
- int ithr_k = ithr / nthr_mn;
- int ithr_n = ithr_mn / nthr_m;
-
- int n_from = 0, n_to = 0, myN = 0;
- int m_from = 0, m_to = 0, myM = 0;
-
- int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
-
- get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
- get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
-
- // sum matrices partitioned along K dimension
- int offset = 0, block = 0;
- gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
- &block);
- for (int ik = 1; ik < nthr_k; ++ik) {
- data_t *myC = c_buffers
- + MB * ((dim_t)NB * (cbase + ik - 1) + offset);
-
- gemm_utils::sum_two_matrices(myM, block, myC, MB,
- &C[m_from + (n_from + offset) * ldc], ldc);
- }
- });
- }
-
- if (bias) {
- parallel_nd(N, M, [&](int i, int j) {
- C[i*ldc + j] += bias[j];
- });
- }
-
- free(ws_buffers);
- free(c_buffers);
-
- return mkldnn_success;
-}
-
-template mkldnn_status_t ref_gemm<float>(
- const char *transa_, const char *transb_,
- const int *M_, const int *N_, const int *K_, const float *alpha_,
- const float *A, const int *lda_, const float *B, const int *ldb_,
- const float *beta_, float *C, const int *ldc_, const float *bias);
-
-template mkldnn_status_t ref_gemm<double>(
- const char *transa_, const char *transb_,
- const int *M_, const int *N_, const int *K_, const double *alpha_,
- const double *A, const int *lda_, const double *B, const int *ldb_,
- const double *beta_, double *C, const int *ldc_, const double *bias);
-}
-}
-}
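Stripped of the blocking, threading and transpose handling, the operation ref_gemm implements is the classic column-major update C = alpha * op(A) * op(B) + beta * C with an optional per-row bias. A minimal sketch for the no-transpose case (an illustrative helper, not part of the library):

static void naive_gemm_nn(int M, int N, int K, float alpha,
        const float *A, int lda, const float *B, int ldb,
        float beta, float *C, int ldc, const float *bias /* may be null */) {
    for (int j = 0; j < N; j++)
        for (int i = 0; i < M; i++) {
            float acc = 0.f;
            for (int p = 0; p < K; p++)
                acc += A[i + p * lda] * B[p + j * ldb]; // column-major indexing
            C[i + j * ldc] = alpha * acc + beta * C[i + j * ldc];
            if (bias)
                C[i + j * ldc] += bias[i]; // one bias entry per row of C
        }
}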
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
deleted file mode 100644
index 7c90ba6277..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
+++ /dev/null
@@ -1,36 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef REF_GEMM_F32_HPP
-#define REF_GEMM_F32_HPP
-
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <typename data_t>
-mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M,
- const int *N, const int *K, const data_t *alpha, const data_t *A,
- const int *lda, const data_t *B, const int *ldb, const data_t *beta,
- data_t *C, const int *ldc, const data_t *bias);
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
deleted file mode 100644
index 3dbe07d743..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
+++ /dev/null
@@ -1,280 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn.h"
-
-#include "mkldnn_traits.hpp"
-#include "nstl.hpp"
-
-#include "jit_generator.hpp"
-
-#include "gemm.hpp"
-
-#include "f32/jit_avx512_common_gemm_f32.hpp"
-#include "f32/jit_avx_gemm_f32.hpp"
-#include "f32/ref_gemm_f32.hpp"
-
-#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
-#include "s8x8s32/simple_gemm_s8s8s32.hpp"
-#include "s8x8s32/ref_gemm_s8x8s32.hpp"
-
-#include "os_blas.hpp"
-
-/* USE_MKL USE_CBLAS effect
- * ------- --------- ------
- * yes yes use Intel(R) MKL CBLAS
- * yes no use jit
- * no yes system-dependent CBLAS
- * no no use jit
- */
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
- const int *M, const int *N, const int *K, const int *lda,
- const int *ldb, const int *ldc, const float *alpha, const float *beta,
- const bool with_bias) {
- if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
- return mkldnn_invalid_arguments;
- if (with_bias && *beta != 0)
- return mkldnn_unimplemented;
- bool consistency = true
- && utils::one_of(*transa, 'T', 't', 'N', 'n')
- && utils::one_of(*transb, 'T', 't', 'N', 'n')
- && *M >= 0
- && *N >= 0
- && *K >= 0;
-
- if (!consistency)
- return mkldnn_invalid_arguments;
- bool isTransA = utils::one_of(*transa, 'T', 't');
- bool isTransB = utils::one_of(*transb, 'T', 't');
- int nrowA = isTransA ? *K : *M;
- int nrowB = isTransB ? *N : *K;
- consistency = true
- && *lda >= nstl::max(1, nrowA)
- && *ldb >= nstl::max(1, nrowB)
- && *ldc >= nstl::max(1, *M);
- if (!consistency)
- return mkldnn_invalid_arguments;
-
- return mkldnn_success;
-}
-
-mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
- const char *transa, const char *transb, const int *M, const int *N,
- const int *K, const int *lda, const int *ldb, const int *ldc,
- const float *alpha, const float *beta, const bool with_bias) {
- if (offsetc == nullptr)
- return mkldnn_invalid_arguments;
- if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
- return mkldnn_invalid_arguments;
-
- return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
- beta, with_bias);
-}
-
-mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
- const int *M, const int *N, const int *K, const float *alpha,
- const float *A, const int *lda, const float *B, const int *ldb,
- const float *beta, float *C, const int *ldc,
- const float *bias, const bool force_jit_gemm) {
- mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
- lda, ldb, ldc, alpha, beta, bias != nullptr);
- if (status != mkldnn_success)
- return status;
-
-#ifdef USE_CBLAS
- if (!force_jit_gemm) {
- bool trA = *transa == 't' || *transa == 'T';
- bool trB = *transb == 't' || *transb == 'T';
- CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
- CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
- cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
- *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
-
- if (bias) {
- // Add bias if necessary (bias is applied to columns of C)
- cblas_int incx = 1, incy = 1;
- parallel_nd(*N, [&](int n) {
- ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
- cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
- });
- }
- return mkldnn_success;
- }
-#endif
-
- if (mayiuse(avx512_common))
- return jit_avx512_common_gemm_f32(transa, transb,
- M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
- else if (mayiuse(avx))
- return jit_avx_gemm_f32(transa, transb,
- M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
- else
- return ref_gemm<float>(transa, transb,
- M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
-}
-
-template <typename b_dt>
-mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
- int32_t *C, const int *LDC, const int32_t *co) {
- mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
- M, N, K, LDA, LDB, LDC, alpha, beta, false);
- if (status != mkldnn_success)
- return status;
-
- if (*M == 0 || *N == 0 || *K == 0)
- return mkldnn_success;
-
-#if USE_MKL_IGEMM
- bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
- bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
- bool AisN = (*transa == 'N' || *transa == 'n');
- bool BisN = (*transb == 'N' || *transb == 'n');
-
- if (data_traits<b_dt>::data_type == data_type::u8) {
- CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
- CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
- CBLAS_OFFSET Cblas_offsetc =
- OCisR
- ? CblasRowOffset
- : OCisC
- ? CblasColOffset
- : CblasFixOffset;
- cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
- *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
- *beta, C, *LDC, co);
- return mkldnn_success;
- } else {
- assert(data_traits<b_dt>::data_type == data_type::s8);
- // TODO CBLAS implementation of gemm_s8s8s32 goes here.
- // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
- if (utils::everyone_is(0, *ao, *bo)) {
- return simple_gemm_s8s8s32(transa, transb, offsetc, M,
- N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
- C, LDC, co);
- } else {
- return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
- alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
- }
- }
-#else
- cpu_isa_t isa = isa_any;
- if (mayiuse(avx512_core_vnni)) {
- isa = avx512_core_vnni;
- } else if (mayiuse(avx512_core)) {
- isa = avx512_core;
- }
-
- if (data_traits<b_dt>::data_type == data_type::u8) {
- switch (isa) {
- case avx512_core:
- case avx512_core_vnni:
- return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
- N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
- C, LDC, co);
- default:
- return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
- alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
- }
- } else {
- assert(data_traits<b_dt>::data_type == data_type::s8);
- // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
- if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
- && *ao == 0 && *bo == 0) {
- return simple_gemm_s8s8s32(transa, transb, offsetc, M,
- N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
- C, LDC, co);
- } else {
- return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
- alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
- }
- }
-#endif
-}
-
-template
-mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const int8_t *B, const int *LDB, const int8_t *bo, const float *beta,
- int32_t *C, const int *LDC, const int32_t *co);
-
-template
-mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta,
- int32_t *C, const int *LDC, const int32_t *co);
-
-}
-}
-}
-
-using namespace mkldnn::impl;
-using namespace mkldnn::impl::cpu;
-
-mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
- const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha,
- const float *A, const int64_t *lda, const float *B, const int64_t *ldb,
- const float *beta, float *C, const int64_t *ldc) {
- int M_s32 = (int)*M;
- int N_s32 = (int)*N;
- int K_s32 = (int)*K;
- int lda_s32 = (int)*lda;
- int ldb_s32 = (int)*ldb;
- int ldc_s32 = (int)*ldc;
-
- return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32,
- alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32);
-}
-
-mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
- const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
- const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
- const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
- int32_t *C, const int64_t *ldc, const int32_t *co) {
- int M_s32 = (int)*M;
- int N_s32 = (int)*N;
- int K_s32 = (int)*K;
- int lda_s32 = (int)*lda;
- int ldb_s32 = (int)*ldb;
- int ldc_s32 = (int)*ldc;
- return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
- alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
-}
-
-mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
- const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
- const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
- const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
- int32_t *C, const int64_t *ldc, const int32_t *co) {
- int M_s32 = (int)*M;
- int N_s32 = (int)*N;
- int K_s32 = (int)*K;
- int lda_s32 = (int)*lda;
- int ldb_s32 = (int)*ldb;
- int ldc_s32 = (int)*ldc;
-
- return gemm_s8x8s32<int8_t>(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
- alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
-}
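As a usage sketch for the public wrappers above (values purely illustrative): every scalar is passed by pointer, dimensions are int64_t, and matrices are column-major, so a small SGEMM call looks like this:

    const int64_t M = 4, N = 3, K = 2;
    const int64_t lda = 4, ldb = 2, ldc = 4;
    const float alpha = 1.0f, beta = 0.0f;
    float A[4 * 2] = {};   // column-major M x K
    float B[2 * 3] = {};   // column-major K x N
    float C[4 * 3] = {};   // column-major M x N, overwritten since beta == 0
    mkldnn_status_t st = mkldnn_sgemm("N", "N", &M, &N, &K, &alpha,
            A, &lda, B, &ldb, &beta, C, &ldc);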
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
deleted file mode 100644
index dc15ff7130..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
+++ /dev/null
@@ -1,58 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef GEMM_HPP
-#define GEMM_HPP
-
-#include "mkldnn_types.h"
-#include "os_blas.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
- const int *M, const int *N, const int *K, const float *alpha,
- const float *A, const int *lda, const float *B, const int *ldb,
- const float *beta, float *C, const int *ldc,
- const float *bias = nullptr, bool force_jit_gemm = false);
-
-template <typename b_dt>
-mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
- const b_dt *B, const int *ldb, const int8_t *bo, const float *beta,
- int32_t *c, const int *ldc, const int32_t *co);
-
-#ifdef USE_CBLAS
-#define GEMM_IMPL_STR "gemm:blas"
-#else
-#define GEMM_IMPL_STR "gemm:jit"
-#endif
-
-#if USE_MKL_IGEMM
-#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas"
-#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas"
-#else
-#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit"
-#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit"
-#endif
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
deleted file mode 100644
index 4d34ede0bd..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
+++ /dev/null
@@ -1,86 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef OS_BLAS_HPP
-#define OS_BLAS_HPP
-
-/** \file
- * Common definitions gated by the USE_MKL and USE_CBLAS compile flags
- *
- * USE_MKL USE_CBLAS effect
- * ------- --------- ------
- * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS
- * yes no jit calls OK; assert if cblas is ever called
- * no yes system-dependent CBLAS
- * no no gemm convolution (or other blas) N/A; create stubs
- */
-
-#if defined(USE_MKL)
-
-#include "mkl_version.h"
-
-#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001)
-#define USE_MKL_IGEMM \
- (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628)
-
-#include "mkl_cblas.h"
-#if !defined(USE_CBLAS)
-#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
-#endif
-
-#else /* defined(USE_MKL) */
-
-#define USE_MKL_PACKED_GEMM 0
-#define USE_MKL_IGEMM 0
-
-#if defined(_SX)
-/* TODO: _SX should also define USE_CBLAS in case the latter is available */
-extern "C" {
-#include "cblas.h" // CHECK: does SX also have a fortran API sgemm?
-}
-
-#elif defined(USE_CBLAS)
-#include "cblas.h" // Maybe a system/cmake cblas works for you?
-#else
-/* provide stubs so the code compiles, even though they are non-functional at runtime */
-#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
-#endif /* defined(_SX) */
-
-#endif /* defined(USE_MKL) */
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-#if defined(USE_MKL) && defined(USE_CBLAS)
-typedef MKL_INT cblas_int;
-
-#elif defined(USE_CBLAS)
-typedef int cblas_int;
-
-#if defined(_SX)
-/* this cblas.h is peculiar... */
-typedef CBLAS_ORDER CBLAS_LAYOUT;
-#endif
-#endif
-
-}
-}
-}
-
-#endif /* OS_BLAS_HPP */
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
deleted file mode 100644
index dde72f4a17..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
+++ /dev/null
@@ -1,206 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef COMMON_H
-#define COMMON_H
-
-#define GEMM_CODE_SIZE (4096L * 32)
-
-#define AVX512_UNROLL_M 48
-#define AVX512_UNROLL_N 8
-#define AVX512_UNROLL_K 1
-#define AVX512_BM 9984
-#define AVX512_BN 384
-#define AVX512_BK 768
-#define AVX512_BK_VNNI 1536
-#define AVX512_BK_TRADITIONAL 384
-#define AVX512_BLOCKING_SMALL_K 48
-#define AVX512_BN_SMALL_K 24
-
-
-#define PAGESIZE 4096
-
-#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE
-#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-enum {
- PARTITION_1D_ROW,
- PARTITION_1D_COL,
- PARTITION_2D_COL_MAJOR,
- PARTITION_2D = PARTITION_2D_COL_MAJOR,
-};
-
-enum {
- COPY_NONE,
- COPY_A,
-};
-
-enum {
- NO_OFFSET,
- FIX_OFFSET,
- COL_OFFSET,
- ROW_OFFSET,
-};
-
-// Alias for any dimension related variable.
-typedef long long int dim_t;
-
-typedef struct {
- // Interface arguments.
- int transa, transb, offsetc;
- dim_t m, n, k;
- dim_t lda, ldb, ldc;
- const int8_t *a;
- const uint8_t *b;
- int32_t *c;
- const float *alpha, *beta;
-
- int8_t ao, bo;
- const int32_t *co;
-
- // Kernel parameters.
- dim_t um, un, uk, bm, bn, bk;
- dim_t bn_small_k, bk_traditional, blocking_small_k;
-
- int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a,
- const dim_t *lda, const int8_t *alpha, int8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a,
- const dim_t *lda, const uint8_t *alpha, uint8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- // Gemv kernels
- void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float,
- const int8_t*, const dim_t, const uint8_t*,
- const float, int32_t*);
-
- void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float,
- const uint8_t*, const dim_t, const int8_t*,
- const float, int32_t*);
-
- // Gemv parameters
- int swap;
-
-} blas_t;
-
-
-class jit_avx512_core_u8_copy_an_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
-
- public:
- jit_avx512_core_u8_copy_an_kern();
-};
-
-class jit_avx512_core_u8_copy_at_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
-
- public:
- jit_avx512_core_u8_copy_at_kern();
-};
-
-class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
-
- public:
- jit_avx512_core_u8_copy_bn_kern();
-};
-
-class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
-
- public:
- jit_avx512_core_u8_copy_bt_kern();
-};
-
-class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
-
- public:
- jit_avx512_core_u8_copy_sum_an_kern();
-};
-
-class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
-
- public:
- jit_avx512_core_u8_copy_sum_at_kern();
-};
-
-class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
-
- public:
- jit_avx512_core_u8_copy_sum_bn_kern();
-};
-
-class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
-
- public:
- jit_avx512_core_u8_copy_sum_bt_kern();
-};
-
-}
-}
-}
-#endif
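A worked example of the page-padding macros defined near the top of this header: with PAGESIZE 4096, a per-thread buffer of 100 int32_t elements is rounded up to a whole page, and the stride between threads is expressed back in elements:

    PADD_BYTESIZE_ONPAGE(100, sizeof(int32_t)) = ((100 * 4 + 4095) / 4096) * 4096 = 4096 bytes
    NEXT_THR_STRIDE(100, sizeof(int32_t))      = 4096 / 4                          = 1024 elements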
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
deleted file mode 100644
index db9dd9ef97..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
+++ /dev/null
@@ -1,28 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg);
-int gemv_threading_driver(blas_t *arg);
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
deleted file mode 100644
index e4b8e1cde2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
+++ /dev/null
@@ -1,1409 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <cstdint>
-#include <mutex>
-
-#include "common.hpp"
-#include "mkldnn_types.h"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_core_gemm_s8u8s32.hpp"
-#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
-#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
-#include "gemv.hpp"
-
-#if defined(_MSC_VER)
-#include <malloc.h>
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-typedef struct {
- int nthrs_m, nthrs_n;
- int partition;
- int copy_type;
-} blas_thread_t;
-
-static inline void round_to_nearest(int32_t *rounded_val, double fp_val) {
- if (fp_val >= 0.) {
- fp_val += 0.5;
- if (fp_val > INT32_MAX) {
- fp_val = INT32_MAX;
- }
- } else {
- fp_val -= 0.5;
- if (fp_val < INT32_MIN) {
- fp_val = INT32_MIN;
- }
- }
- *rounded_val = (int32_t) fp_val;
-}
-
-static inline void add_results(const dim_t m, const dim_t n, const dim_t k,
- const float alpha, const float beta, const int32_t *c_partial_sum,
- const dim_t ldcp, int32_t *c_data, const dim_t ldc,
- const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao,
- const int8_t bo, const int32_t *co, const int offsetc)
-{
- for (dim_t j = 0; j < n; ++j) {
- for (dim_t i = 0; i < m; ++i) {
- int32_t ctemp = c_partial_sum[i + j * ldcp];
-
- if (alpha == 1.0f) {
- if (beta == 0.0f) {
- c_data[i + j * ldc] = ctemp;
- } else {
- double c_float = (double) beta
- * (double) c_data[i + j * ldc];
- c_float += (double) ctemp;
- round_to_nearest(&c_data[i + j * ldc], c_float);
- }
- } else if (alpha == -1.0f) {
- if (beta == 0.0f) {
- c_data[i + j * ldc] = -ctemp;
- } else {
- double c_float = (double) beta
- * (double) c_data[i + j * ldc];
- c_float -= (double) ctemp;
- round_to_nearest(&c_data[i + j * ldc], c_float);
- }
- } else {
- if (beta == 0.0f) {
- double c_float = alpha * (double) ctemp;
- round_to_nearest(&c_data[i + j * ldc], c_float);
- } else {
- double c_float = alpha * (double) ctemp +
- beta * (double) c_data[i + j * ldc];
- round_to_nearest(&c_data[i + j * ldc], c_float);
- }
- }
-
- if (offsetc == FIX_OFFSET) {
- c_data[i + j * ldc] += co[0];
- } else if (offsetc == ROW_OFFSET) {
- c_data[i + j * ldc] += co[j];
- } else if (offsetc == COL_OFFSET) {
- c_data[i + j * ldc] += co[i];
- }
- }
- }
-}
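In effect, for every element the function above computes C[i,j] = round_to_nearest(alpha * c_partial[i,j] + beta * C[i,j]), with round_to_nearest saturating to the int32 range, and then adds co[0], co[j] or co[i] depending on whether the C offset is fixed, per-row or per-column; the alpha == 1 and alpha == -1 branches are just cheaper special cases of the same formula.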
-
-// TODO Find a better place for those functions.
-static inline dim_t ld_padd(const dim_t x)
-{
- return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t)))
- * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t));
-}
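For example, ld_padd(100) with sizeof(int32_t) == 4 evaluates to ((100 + 511) / 512) * 512 + 16 = 528: the leading dimension of the temporary C buffer is rounded up to a multiple of 512 int32_t elements (2 KB) and then offset by 16 elements (one 64-byte cache line), likely to avoid cache-set aliasing between consecutive columns.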
-
-void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k,
- const int8_t *a, const uint8_t *b, float beta, int32_t *c,
- const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum,
- const int32_t *co, const int offsetc, const blas_t *arg)
-{
- int8_t ao = arg->ao;
- int8_t bo = arg->bo;
- int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0];
-
-    // Since m and n are limited by blocking, these stack allocations stay
-    // small (up to about 32 kB), so stack overflow is not a concern.
-#if !defined(_MSC_VER)
- int32_t col_offset[m];
- int32_t row_offset[n];
-#else
- int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m);
- int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n);
-#endif
-
- int col_req = 0;
- int row_req = 0;
-
- if ((bo != 0) || (offsetc == COL_OFFSET))
- col_req = 1;
- if ((ao != 0) || (offsetc == ROW_OFFSET))
- row_req = 1;
-
-    // Only one of the column or row offsets is needed, not both
- if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) {
- if ((col_req == 0) && (row_req == 0)) {
- if (m <= n) {
- col_req = 1;
- } else {
- row_req = 1;
- }
- }
- }
-
- if (col_req) {
- for (dim_t i = 0; i < m; i++)
- col_offset[i] = 0;
-
- if (offsetc == COL_OFFSET) {
- for (dim_t i = 0; i < m; i++)
- col_offset[i] += co[i];
- }
-
- if (bo != 0) {
- for (dim_t i = 0; i < m; i++)
- col_offset[i] += bo * a_row_sum[i];
- }
- }
-
- if (row_req) {
- for (dim_t i = 0; i < n; i++)
- row_offset[i] = 0;
-
- if (offsetc == ROW_OFFSET) {
- for (dim_t i = 0; i < n; i++)
- row_offset[i] += co[i];
- }
-
- if (ao != 0) {
- for (dim_t i = 0; i < n; i++)
- row_offset[i] += ao * b_col_sum[i];
- }
- }
-
- if ((offsetc == FIX_OFFSET) && (co_0 != 0)) {
- if (col_req) {
- for (dim_t i = 0; i < m; i++)
- col_offset[i] += co_0;
- } else {
- for (dim_t i = 0; i < n; i++)
- row_offset[i] += co_0;
- }
- }
-
- if ((ao != 0) && (bo != 0)) {
- if (col_req) {
- for (dim_t i = 0; i < m; i++)
- col_offset[i] += (int32_t) k * ao * bo;
- } else {
- for (dim_t i = 0; i < n; i++)
- row_offset[i] += (int32_t) k * ao * bo;
- }
- }
-
- if (col_req == 0) {
- if (row_req == 0) {
- if (beta == 0.0) {
- arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- } else {
- arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- }
- } else {
- if (beta == 0.0) {
- arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- } else {
- arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- }
- }
- } else {
- if (row_req == 0) {
- if (beta == 0.0) {
- arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- } else {
- arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- }
- } else {
- if (beta == 0.0) {
- arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- } else {
- arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
- row_offset);
- }
- }
- }
-}
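The col_offset/row_offset bookkeeping above follows from expanding the offset GEMM, consistent with the code: sum_p (A[i,p] + ao) * (B[p,j] + bo) = (A*B)[i,j] + bo * a_row_sum[i] + ao * b_col_sum[j] + k * ao * bo. The bo * a_row_sum term folds into col_offset (one entry per row of C), the ao * b_col_sum term folds into row_offset (one entry per column), and the constant k * ao * bo, plus a fixed co[0] when present, only needs to be added to one of the two vectors.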
-
-static inline void *align(void *ptr, size_t alignment)
-{
- return (void *) utils::rnd_up((uintptr_t) ptr, alignment);
-}
-
-static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k,
- const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co,
- const blas_t *arg)
-{
- dim_t lda = arg->lda;
- dim_t ldb = arg->ldb;
- dim_t ldc = arg->ldc;
- int8_t ao = arg->ao;
- int8_t bo = arg->bo;
- float alpha = *arg->alpha;
- float beta = *arg->beta;
-
- if (m <= 0 || n <= 0) {
- return 0;
- }
-
- // Padding along K dimension.
- dim_t k_padd = 0;
- if (k <= arg->bk_traditional) {
- k_padd = utils::rnd_up(k, arg->uk);
- k_padd = nstl::max(128LL, k_padd);
- } else if (k < 2 * arg->bk) {
- k_padd = utils::rnd_up(k / 2, arg->uk);
- } else {
- k_padd = arg->bk;
- }
-
- // Padding along M dimension.
- dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
- arg->um);
-
- // Padding along N dimension.
- dim_t n_padd = 0;
- if (k < arg->blocking_small_k) {
- n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
- arg->bn_small_k), arg->un);
- } else {
- n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
- arg->un);
- }
-
- // Padding for temporary buffer for C
- dim_t ldc_buf = ld_padd(m_padd);
-
- dim_t strideAm = (arg->transa == 0)? 1 : lda;
- dim_t strideAn = (arg->transa != 0)? 1 : lda;
- dim_t strideBm = (arg->transb == 0)? 1 : ldb;
- dim_t strideBn = (arg->transb != 0)? 1 : ldb;
-
- size_t a_buf_nelems = m_padd * k_padd;
- size_t b_buf_nelems = k_padd * n_padd;
- size_t a_row_sum_nelems = m_padd;
- size_t b_col_sum_nelems = n_padd;
-
- size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
- + b_buf_nelems * sizeof(*b) + PAGE_4K
- + a_row_sum_nelems * sizeof(*c) + PAGE_4K
- + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
-
- bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
- if (need_c_buffer) {
- size_t c_buf_nelems = ldc_buf * n_padd;
- mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
- }
-
- char *mem = (char *) malloc(mem_size, 128);
-
- if (!mem) {
- return -1;
- }
-
- int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
- uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K);
- int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
- int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems,
- PAGE_4K);
-
- int32_t *bufferC = NULL;
- if (need_c_buffer) {
- bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
- }
-
- float beta_saved = beta;
-
- int a_block_copied = 0;
- dim_t sizeM = 0;
- for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
- sizeM = m - Bm;
- if (sizeM > m_padd)
- sizeM = m_padd;
-
- dim_t sizeK = 0;
- for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
- sizeK = k - Bk;
- if (sizeK > k_padd)
- sizeK = k_padd;
-
-            // Scale C blocks by beta only for the first k-block
- if (Bk == 0)
- beta = beta_saved;
- else
- beta = 1.0f;
-
-            // Apply the C offset only to the last k-block of the partial sum.
- int offsetc = NO_OFFSET;
- if (Bk + sizeK == k)
- offsetc = arg->offsetc;
-
- dim_t sizeN = 0;
- for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
- sizeN = n - Bn;
- if (sizeN > n_padd)
- sizeN = n_padd;
-
- const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn;
- arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL,
- NULL, b_col_sum);
-
- dim_t sizeUM = 0;
- for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
- sizeUM = sizeM - Um;
- if (sizeUM > arg->um)
- sizeUM = arg->um;
-
-                    /*
-                     * Use the whole A buffer only if we have multiple B blocks
-                     * for the k dimension; otherwise we would be wasting cache
-                     * storing the B and C blocks.
-                     */
- dim_t Um_forA = 0;
- if (sizeN < n)
- Um_forA = Um;
-
- const int8_t *a_block = a + (Bm + Um) * strideAm
- + Bk * strideAn;
- if (!a_block_copied) {
- arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL,
- bufferA + Um_forA * sizeK, NULL, NULL,
- a_row_sum + Um_forA);
- }
-
- int32_t *c_block = c + (Bm + Um) + Bn * ldc;
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = Bn;
- } else if (offsetc == COL_OFFSET) {
- co_stride = Bm + Um;
- }
- if (need_c_buffer) {
- igemm_inner_kernel(sizeUM, sizeN, sizeK,
- bufferA + Um_forA * sizeK, bufferB, 0.0f,
- bufferC + Um, ldc_buf, a_row_sum + Um_forA,
- b_col_sum, NULL, NO_OFFSET, arg);
-
- // Finish the block adding the necessary alpha, beta
- // and offsets.
- add_results(sizeUM, sizeN, sizeK, alpha, beta,
- bufferC + Um, ldc_buf, c_block, ldc,
- a_row_sum + Um_forA, b_col_sum, ao, bo,
- co + co_stride, offsetc);
- } else {
- igemm_inner_kernel(sizeUM, sizeN, sizeK,
- bufferA + Um_forA * sizeK, bufferB, beta,
- c_block, ldc, a_row_sum + Um_forA, b_col_sum,
- co + co_stride, offsetc, arg);
- }
- }
- a_block_copied = 1;
- }
- a_block_copied = 0;
- }
- }
-
- free(mem);
-
- return 0;
-}
-
-static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n,
- const dim_t k, const int8_t *bufferA, const uint8_t *b,
- const float beta, int32_t *c, const int offsetc, const int32_t *co,
- const int32_t *a_row_sum, const blas_t *arg)
-{
- dim_t ldb = arg->ldb;
- dim_t ldc = arg->ldc;
- int8_t ao = arg->ao;
- int8_t bo = arg->bo;
- float alpha = *arg->alpha;
-
- if (m <= 0 || n <= 0) {
- return 0;
- }
-
- // Padding along N dimension.
- dim_t n_padd = 0;
- if (k < arg->blocking_small_k) {
- n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
- arg->bn_small_k), arg->un);
- } else {
- n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
- arg->un);
- }
-
- // Padding for temporary buffer for C
- dim_t ldc_buf = ld_padd(m);
-
- dim_t strideBn = (arg->transb != 0)? 1 : ldb;
-
- size_t b_buf_nelems = k * n_padd;
- size_t b_col_sum_nelems = n_padd;
-
- size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K
- + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
-
- bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
- if (need_c_buffer) {
- size_t c_buf_nelems = ldc_buf * n_padd;
- mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
- }
-
- char *mem = (char *) malloc(mem_size, 128);
-
- if (!mem) {
- return -1;
- }
-
- uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K);
- int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
-
- int32_t *bufferC = NULL;
- if (need_c_buffer) {
- bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
- }
-
- dim_t sizeN = 0;
- for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
- sizeN = n - Bn;
- if (sizeN > n_padd)
- sizeN = n_padd;
-
-        // Copy the next block of B into the packed buffer.
- const uint8_t *b_block = b + Bn * strideBn;
- arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL,
- b_col_sum);
-
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = Bn;
- } else if (offsetc == COL_OFFSET) {
- co_stride = 0;
- }
- int32_t *c_block = c + Bn * ldc;
- if (need_c_buffer) {
- igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC,
- ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg);
-
- // Finish the block adding the necessary alpha, beta and offsets.
- add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block,
- ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride,
- offsetc);
- } else {
- igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block,
- ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
- }
- }
-
- free(mem);
-
- return 0;
-
-}
-
-#define N2D_MAX_AVX512 384
-#define M2D_MIN_AVX512 384
-#define VECLEN 16
-#define NCONS 1
-static inline void set_thread_opts_avx512(int *p_nthrs,
- blas_thread_t *thread_info, const blas_t *arg)
-{
- int nthrs = *p_nthrs;
- dim_t m = arg->m;
- dim_t n = arg->n;
-
- thread_info->nthrs_m = 0;
- thread_info->nthrs_n = 0;
- thread_info->copy_type = COPY_NONE; // By default don't do parallel copy.
-
- int condition_2D_bsrc = -1;
- if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) {
- condition_2D_bsrc = 1;
- } else {
- condition_2D_bsrc = 0;
- }
-
- int condition_1D_copya = 0;
- if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) {
- condition_2D_bsrc = 0;
- condition_1D_copya = 1;
- }
-
-    // If any offset is non-zero, fall back to the 1D copy-A algorithm to reduce the update overhead
- if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0
- || arg->offsetc != FIX_OFFSET) {
- condition_2D_bsrc = 0;
- condition_1D_copya = 1;
- }
-
- if (condition_2D_bsrc == 1) {
- int nthrs_m = 1;
- int nthrs_n = nthrs;
-
- while ((nthrs_n % 2 == 0) &&
- (n / nthrs > N2D_MAX_AVX512 ||
- n / nthrs_n <= N2D_MAX_AVX512 / 2) &&
- (m / nthrs_m >= 2 * M2D_MIN_AVX512) &&
- (nthrs_m < 4)) {
- nthrs_m *= 2;
- nthrs_n /= 2;
- }
-
- thread_info->nthrs_m = nthrs_m;
- thread_info->nthrs_n = nthrs_n;
- thread_info->partition = PARTITION_2D;
-
- // Reset the total number of threads that will be used.
- *p_nthrs = nthrs_m * nthrs_n;
-
- } else if (condition_1D_copya && mkldnn_thr_syncable()) {
- // Use parallel copy A algorithm
- thread_info->copy_type = COPY_A;
- thread_info->partition = PARTITION_1D_COL;
- } else {
- if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) {
- thread_info->partition = PARTITION_1D_ROW;
- } else {
- thread_info->partition = PARTITION_1D_COL;
- }
- }
-}
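A worked example of the 2D split above (illustrative numbers, assuming zero offsets so the copy-A path is not forced): with nthrs = 16, m = 4000, n = 1000 the 2D condition holds; the while loop then checks n / nthrs_n <= 192 (1000/16 = 62, then 1000/8 = 125) and m / nthrs_m >= 768, growing the grid from 1 x 16 to 2 x 8 to 4 x 4, where 1000/4 = 250 > 192 stops it, so the GEMM runs on a 4 x 4 thread grid.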
-#undef N2D_MAX_AVX512
-#undef M2D_MIN_AVX512
-#undef VECLEN
-#undef NCONS
-
-static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
- dim_t *t_offset, dim_t *t_block)
-{
- dim_t band = n / nthrs;
-
- dim_t tail = n - (nthrs - 1) * band;
- if (tail > (band + 1))
- band++;
- tail = n - (nthrs - 1) * band;
-
- if (ithr < (nthrs - 1))
- *t_block = band;
- else
- *t_block = tail;
-
- *t_offset = ithr * band;
-
- if (*t_offset >= n) {
- *t_block = 0;
- *t_offset = 0;
- } else if ((*t_offset + *t_block) > n) {
- *t_block = n - *t_offset;
- }
-}
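For example, partition_1d with n = 10 and nthrs = 4 first gets band = 2 and tail = 10 - 3*2 = 4; since 4 > band + 1, band is bumped to 3 and tail becomes 1, so threads 0-2 take blocks of 3 at offsets 0, 3 and 6, and the last thread takes the single remaining element at offset 9.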
-
-static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
- const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
- const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
- dim_t *p_n_band)
-{
- dim_t m_disp = 0, n_disp = 0;
- dim_t m_band = 0, n_band = 0;
-
- int mdiv = nthrs_m;
- int ndiv = nthrs_n;
-
- dim_t m_bandt = m / mdiv; /* size per thread */
- dim_t n_bandt = n / ndiv; /* size per thread */
- int firstmgroup = mdiv - 1;
- int firstngroup = ndiv - 1;
- dim_t firstmval = m_bandt;
- dim_t firstnval = n_bandt;
-
- int mthr_used = mdiv;
- if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
- if (m - (mdiv - 1) * m_bandt > mdiv)
- ++m_bandt;
-
- firstmval = m_bandt + 1;
- mthr_used = (int) (m / firstmval);
-
- if (mthr_used * firstmval < m)
- ++mthr_used;
-
- firstmgroup = mthr_used - 1;
- }
-
- int nthr_used = ndiv;
- if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
- firstnval = n_bandt + 1;
- nthr_used = (int) (n / firstnval);
-
- if (nthr_used * firstnval < n)
- ++nthr_used;
-
- firstngroup = nthr_used - 1;
- }
-
- *nthrs = mthr_used * nthr_used;
-
- if (ithr < *nthrs) {
- if (ithr_i < firstmgroup) {
- m_band = firstmval;
- m_disp = ithr_i * firstmval;
- } else if (ithr_i <= mthr_used - 2) {
- m_band = m_bandt;
- m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
- } else {
- m_disp = firstmgroup * firstmval
- + (mthr_used - 1 - firstmgroup) * m_bandt;
- m_band = nstl::max(0LL, m - m_disp);
- }
-
- if (ithr_j < firstngroup) {
- n_band = firstnval;
- n_disp = ithr_j * firstnval;
- } else if (ithr_j <= nthr_used - 2) {
- n_band = n_bandt;
- n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
- } else {
- n_disp = firstngroup * firstnval
- + (nthr_used - 1 - firstngroup) * n_bandt;
- n_band = nstl::max(0LL, n - n_disp);
- }
- m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL);
- n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL);
- }
-
- if (ithr < *nthrs) {
- *p_m_disp = m_disp;
- *p_n_disp = n_disp;
- *p_m_band = m_band;
- *p_n_band = n_band;
- } else {
- *p_m_disp = 0;
- *p_n_disp = 0;
- *p_m_band = 0;
- *p_n_band = 0;
- }
-
- return;
-}
-
-static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m,
- dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c,
- const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg)
-{
- dim_t strideAm = (arg->transa == 0)? 1 : arg->lda;
- dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb;
- int offsetc = arg->offsetc;
-
- switch (thread_info->partition) {
- case PARTITION_1D_ROW:
- {
- dim_t offset = 0;
- dim_t block = 0;
- partition_1d(ithr, *nthrs, arg->m, &offset, &block);
-
- *m = block;
- *n = arg->n;
- *k = arg->k;
-
- // Set matrix A.
- *a = arg->a + offset * strideAm;
-
- // Set matrix B.
- *b = arg->b;
-
- // Set matrix C.
- *c = arg->c + offset;
-
- // Set offset vector for C matrix
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = 0;
- } else if (offsetc == COL_OFFSET) {
- co_stride = offset;
- }
- *co = arg->co + co_stride;
- break;
- }
-
- case PARTITION_1D_COL:
- {
- dim_t offset = 0;
- dim_t block = 0;
- partition_1d(ithr, *nthrs, arg->n, &offset, &block);
-
- *m = arg->m;
- *n = block;
- *k = arg->k;
-
- // Set matrix A.
- *a = arg->a;
-
- // Set matrix B.
- *b = arg->b + offset * strideBn;
-
- // Set matrix C.
- *c = arg->c + offset * arg->ldc;
-
- // Set offset vector for C matrix
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = offset;
- } else if (offsetc == COL_OFFSET) {
- co_stride = 0;
- }
- *co = arg->co + co_stride;
- break;
- }
-
- case PARTITION_2D_COL_MAJOR:
- {
- int nthrs_m = thread_info->nthrs_m;
- int nthrs_n = thread_info->nthrs_n;
- int ithr_i = ithr % nthrs_m;
- int ithr_j = ithr / nthrs_m;
-
- dim_t m_disp = 0;
- dim_t m_band = 0;
- dim_t n_disp = 0;
- dim_t n_band = 0;
-
- partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n,
- arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band);
-
- *m = m_band;
- *n = n_band;
- *k = arg->k;
-
- // Set matrix A.
- *a = arg->a + m_disp * strideAm;
-
- // Set matrix B.
- *b = arg->b + n_disp * strideBn;
-
- // Set matrix C.
- *c = arg->c + m_disp + n_disp * arg->ldc;
-
- // Set offset vector for C matrix
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = n_disp;
- } else if (offsetc == COL_OFFSET) {
- co_stride = m_disp;
- }
- *co = arg->co + co_stride;
- break;
- }
- }
-}
-
-#define MULTIPLIER 10
-static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m,
- const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b,
- int32_t *c, const int32_t *co, const blas_t *arg,
- char **p_shared_mem)
-{
- const dim_t lda = arg->lda;
- const dim_t ldb = arg->ldb;
- const dim_t strideAm = (arg->transa == 0)? 1 : lda;
- const dim_t strideAn = (arg->transa != 0)? 1 : lda;
- const dim_t strideBm = (arg->transb == 0)? 1 : ldb;
-
- // Padding along M dimension.
- dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
- arg->um);
-
- // Padding along K dimension.
- dim_t k_padd = 0;
- if (k <= arg->bk_traditional) {
- k_padd = utils::rnd_up(k, arg->uk);
- k_padd = nstl::max(128LL, k_padd);
- } else if (k < 2 * arg->bk) {
- k_padd = utils::rnd_up(k / 2, arg->uk);
- } else {
- k_padd = arg->bk;
- }
-
- m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs;
- if (m_padd > m) {
- m_padd = utils::rnd_up(m, arg->um);
- }
-
- size_t a_buf_nelems = m_padd * k_padd;
-
- // Allocate shared memory for A and its row sum buffers in master thread.
-    if (ithr == 0) { // master thread
- size_t a_row_sum_nelems = m_padd;
-
- size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K)
- + a_row_sum_nelems * sizeof(*c) + PAGE_4K;
-
- *p_shared_mem = (char *) malloc(mem_size, 128);
-
- }
- mkldnn_thr_barrier();
-
- char *mem = *p_shared_mem;
- int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
- int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K);
-
- if (!mem) {
- return -1;
- }
-
- int result = 0; // Return status
-
- dim_t sizeK = 0;
- for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
- sizeK = k - Bk;
- if (sizeK > k_padd)
- sizeK = k_padd;
-
-        // Scale C blocks by beta only for the first term of the partial sum.
- float beta = 1.0f;
- if (Bk == 0)
- beta = *(arg->beta);
-
- // Apply C offset for the last k-block of the partial sum.
- int offsetc = NO_OFFSET;
- if (Bk + sizeK == k)
- offsetc = arg->offsetc;
-
- dim_t sizeM = 0;
- for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
- sizeM = m - Bm;
- if (sizeM > m_padd)
- sizeM = m_padd;
-
- if (ithr < nthrs) {
- dim_t band = (sizeM + nthrs - 1) / nthrs;
- band = utils::rnd_up(band, arg->um);
-
- dim_t offset = band * ithr;
-
- // If offset is too large don't use that thread for copying.
- if (offset >= sizeM) {
- offset = 0;
- band = 0;
- }
-
- // Handle the tail of the copy.
- if (offset + band > sizeM) {
- band = sizeM - offset;
- }
-
- if (band > 0) {
- const int8_t *a_block = a + (Bm + offset) * strideAm
- + Bk * strideAn;
- arg->copyA(&sizeK, &band, a_block, &lda, NULL,
- bufferA + offset * sizeK, NULL, NULL,
- a_row_sum + offset);
- }
- }
-            mkldnn_thr_barrier(); // Wait for the parallel copy to finish.
-
- const uint8_t *b_block = b + Bk * strideBm;
- int32_t *c_block = c + Bm;
- dim_t co_stride = 0;
- if (offsetc == FIX_OFFSET) {
- co_stride = 0;
- } else if (offsetc == ROW_OFFSET) {
- co_stride = 0;
- } else if (offsetc == COL_OFFSET) {
- co_stride = Bm;
- }
-
- result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK,
- bufferA, b_block, beta, c_block, offsetc, co + co_stride,
- a_row_sum, arg);
-
- mkldnn_thr_barrier(); // Wait for kernel computations to finish.
- }
- }
-
- // Free memory allocated in master thread
- if (ithr == 0) {
- free(mem);
- }
-
- return result;
-}
-#undef MULTIPLIER
-
-static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
- double fp_per_cycle, int *nthrs)
-{
- double omp_overhead_small_core = 3.0e+3;
- double omp_intercept_big_core = 4.0e+3;
- double omp_slope_big_core = 5.0e+2;
-
- double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
-
- int i = *nthrs;
-
- // Use a different model for omp overheads if nthrs is <= 4
- if (*nthrs <= 4 && omp_overhead_small_core > 0) {
- double omp_cycles = omp_overhead_small_core;
- if (gemm_cycles < omp_cycles) {
- *nthrs = 1;
- return;
- } else {
- while (i > 1) {
- if (omp_cycles * i < gemm_cycles * (i - 1)) break;
- --i;
- }
- }
- } else {
- if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
- *nthrs = 1;
- return;
- }
-
-        // adaptive decrement to march faster
- while (i > 1) {
- double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
- if (omp_cycles * i < gemm_cycles * (i - 1))
- break;
-
- if (i < 10)
- i -= 2;
- else if (i < 30)
- i -= 4;
- else
- i -= 8;
- }
- }
-
- if (i < 1)
- i = 1;
-
- *nthrs = i;
-}
-
-#define CACHE_LINE_SIZE 64
-static int gemm_threading_driver(blas_t *arg)
-{
- if ((arg->m <= 0) || (arg->n <= 0))
- return mkldnn_success;
-
- if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
- return mkldnn_success;
- }
-
- int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
- get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
-
- if (nthr == 1) {
- return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
- arg->c, arg->co, arg);
- }
-
- int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
- PAGE_4K);
-
- if (!results) {
- return -1;
- }
-
- for (int i = 0; i < nthr; i++) {
- results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
- }
-
- char *shared_mem = NULL;
-
- parallel(nthr, [&](const int ithr, const int nthr) {
- int nthrs = nthr;
- if (nthrs == 1) {
- results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
- arg->b, arg->c, arg->co, arg);
- } else {
- blas_thread_t thread_info;
- set_thread_opts_avx512(&nthrs, &thread_info, arg);
-
- const int8_t *a = NULL;
- const uint8_t *b = NULL;
- int32_t *c = NULL;
- const int32_t *co = NULL;
- dim_t m = -1;
- dim_t n = -1;
- dim_t k = -1;
- decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
- &thread_info, arg);
-
- if (ithr < nthrs) {
- switch (thread_info.copy_type) {
- case COPY_A:
- results[ithr * CACHE_LINE_SIZE] =
- parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
- &shared_mem);
- break;
-
- default:
- case COPY_NONE:
- results[ithr * CACHE_LINE_SIZE] =
- gemm_kernel_driver(m, n, k, a, b, c, co, arg);
- break;
- }
- }
- }
- });
-
- int result = 0; // Initialize to success
- for (int i = 0; i < nthr; i++) {
-        if (results[i * CACHE_LINE_SIZE] != 0) {
- result = results[i * CACHE_LINE_SIZE];
- break;
- }
- }
-
- free(results);
-
- return result;
-}
-#undef CACHE_LINE_SIZE
-
-static jit_avx512_core_u8_copy_an_kern *copy_an;
-static jit_avx512_core_u8_copy_at_kern *copy_at;
-static jit_avx512_core_u8_copy_bn_kern *copy_bn;
-static jit_avx512_core_u8_copy_bt_kern *copy_bt;
-static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
-static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
-static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
-static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_r;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_c;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r;
-static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c;
-static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel;
-static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel;
-
-static void jit_init(blas_t *arg)
-{
- static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a,
- const dim_t *lda, const int8_t *alpha, int8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a,
- const dim_t *lda, const int8_t *alpha, int8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
- const dim_t *lda, const uint8_t *alpha, uint8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
- const dim_t *lda, const uint8_t *alpha, uint8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a,
- const dim_t *lda, const int8_t *alpha, int8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a,
- const dim_t *lda, const int8_t *alpha, int8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
- const dim_t *lda, const uint8_t *alpha, uint8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
- const dim_t *lda, const uint8_t *alpha, uint8_t *b,
- const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
-
- static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
- const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
- const dim_t ldc, const int32_t *col_offset,
- const int32_t *row_offset);
-
- static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float,
- const int8_t*, const dim_t, const uint8_t*,
- const float, int32_t*);
-
- static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float,
- const uint8_t*, const dim_t, const int8_t*,
- const float, int32_t*);
-
- if (mayiuse(avx512_core_vnni)) {
- arg->um = AVX512_UNROLL_M;
- arg->un = AVX512_UNROLL_N;
- arg->uk = AVX512_UNROLL_K;
- arg->bm = AVX512_BM;
- arg->bn = AVX512_BN;
- arg->bk = AVX512_BK_VNNI;
-
- arg->bk_traditional = AVX512_BK_TRADITIONAL;
- arg->bn_small_k = AVX512_BN_SMALL_K;
- arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
- } else {
- arg->um = AVX512_UNROLL_M;
- arg->un = AVX512_UNROLL_N;
- arg->uk = AVX512_UNROLL_K;
- arg->bm = AVX512_BM;
- arg->bn = AVX512_BN;
- arg->bk = AVX512_BK;
-
- arg->bk_traditional = AVX512_BK_TRADITIONAL;
- arg->bn_small_k = AVX512_BN_SMALL_K;
- arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
- }
-
- static std::once_flag initialized;
- std::call_once(initialized, []{
-
- copy_an = new jit_avx512_core_u8_copy_an_kern();
- copy_at = new jit_avx512_core_u8_copy_at_kern();
- copy_bn = new jit_avx512_core_u8_copy_bn_kern();
- copy_bt = new jit_avx512_core_u8_copy_bt_kern();
-
- copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern();
- copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern();
- copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern();
- copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern();
-
- kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false);
- kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true);
- kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true);
- kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false);
- kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false);
- kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true);
- kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true);
- kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false);
-
- gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
- gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
-
-
- copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *,
- const int8_t *, const dim_t *, const int8_t *, int8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *,
- const int8_t *, const dim_t *, const int8_t *, int8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *,
- const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *,
- const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *,
- const int8_t *, const dim_t *, const int8_t *, int8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *,
- const int8_t *, const dim_t *, const int8_t *, int8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *,
- const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *,
- const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
- const dim_t *, const dim_t *, int32_t *)>();
-
- kern = kernel->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *,
- const dim_t *, const float *, const int8_t *, const uint8_t *,
- int32_t *, const dim_t, const int32_t *, const int32_t *)>();
-
- gemv_s8u8s32_kern =
- gemv_s8u8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>
- (mayiuse(avx512_core_vnni));
- gemv_u8s8s32_kern =
- gemv_u8s8s32_kernel -> generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>
- (mayiuse(avx512_core_vnni));
- });
-
- if (arg->bo == 0) { // No need to compute A row sum if bo is zero
- if (arg->transa == 0) {
- arg->copyA = copyAn;
- } else {
- arg->copyA = copyAt;
- }
- } else {
- if (arg->transa == 0) {
- arg->copyA = copySumAn;
- } else {
- arg->copyA = copySumAt;
- }
- }
-
- if (arg->ao == 0) { // No need to compute B column sum if ao is zero
- if (arg->transb == 0) {
- arg->copyB = copyBn;
- } else {
- arg->copyB = copyBt;
- }
- } else {
- if (arg->transb == 0) {
- arg->copyB = copySumBn;
- } else {
- arg->copyB = copySumBt;
- }
- }
-
- arg->kernel = kern;
- arg->kernel_b = kern_b;
- arg->kernel_r = kern_r;
- arg->kernel_c = kern_c;
- arg->kernel_b0 = kern_b0;
- arg->kernel_b0_b = kern_b0_b;
- arg->kernel_b0_r = kern_b0_r;
- arg->kernel_b0_c = kern_b0_c;
- arg -> gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
- arg -> gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
-}
-
-mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
- const char *transA, const char *transB, const char *offsetC,
- const int *m, const int *n, const int *k,
- const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
- const uint8_t *b, const int *ldb, const int8_t *ob,
- const float *beta, int32_t *c, const int *ldc, const int32_t *oc)
-{
- char transa = *transA;
- char transb = *transB;
- char offsetc = *offsetC;
-
- blas_t args;
-
- // Initialize blas structure
- args.m = *m;
- args.n = *n;
- args.k = *k;
- args.alpha = alpha;
- args.a = a;
- args.lda = *lda;
- args.b = b;
- args.ldb = *ldb;
- args.beta = beta;
- args.c = c;
- args.ldc = *ldc;
- args.transa = (transa == 'N' || transa == 'n') ? 0 : 1;
- args.transb = (transb == 'N' || transb == 'n') ? 0 : 1;
- args.um = 0;
- args.un = 0;
- args.bm = 0;
- args.bn = 0;
- args.bk = 0;
- args.copyA = NULL;
- args.copyB = NULL;
- args.kernel = NULL;
- args.kernel_b0 = NULL;
- args.ao = *oa;
- args.bo = *ob;
- args.co = oc;
-
- if (offsetc == 'F' || offsetc == 'f') {
- args.offsetc = FIX_OFFSET;
- } else if (offsetc == 'R' || offsetc == 'r') {
- args.offsetc = ROW_OFFSET;
- } else { // offsetc == 'C' || offsetc == 'c'
- args.offsetc = COL_OFFSET;
- }
-
- jit_init(&args);
- int result = gemm_threading_driver(&args);
-
- return (result < 0) ? mkldnn_out_of_memory : mkldnn_success;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
deleted file mode 100644
index b2e2902a12..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP
-#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP
-
-#include <cstdint>
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
- const char *transA, const char *transB, const char *offsetC,
- const int *m, const int *n, const int *k,
- const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
- const uint8_t *b, const int *ldb, const int8_t *ob,
- const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
-
-}
-}
-}
-
-#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
deleted file mode 100644
index 57554a1852..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
+++ /dev/null
@@ -1,539 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
-
-
-#ifdef _WIN32
-static const bool is_windows = 1;
-#else
-static const bool is_windows = 0;
-#endif
-
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-
-
-
-// Convert between vector register lengths.
-static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); }
-static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); }
-
-// Load from or store to C.
-void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst,
- const Xbyak::Address &src, int nelems)
-{
- switch (nelems) {
- default: vmovups(dst, src); break;
- case 8: vmovups(make_ymm(dst), src); break;
- case 4: vmovups(make_xmm(dst), src); break;
- case 2: vmovlps(make_xmm(dst), src); break;
- case 1: vmovss(make_xmm(dst), src); break;
- }
-}
-void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst,
- const Xbyak::Xmm &src, int nelems)
-{
- switch (nelems) {
- default: vmovups(dst, src); break;
- case 8: vmovups(dst, make_ymm(src)); break;
- case 4: vmovups(dst, make_xmm(src)); break;
- case 2: vmovsd(dst, make_xmm(src)); break;
- case 1: vmovss(dst, make_xmm(src)); break;
- }
-}
-
-// Perform length-4 dot product accumulations of unsigned and signed bytes
-// in parallel.
-// Use vpdpbusd if VNNI available, otherwise emulate.
-void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst,
- const Xmm &src1, const Xmm &src2)
-{
- if (vnni)
- vpdpbusd(dst, src1, src2);
- else {
- vpmaddubsw(dp_scratch, src1, src2);
- vpmaddwd(dp_scratch, ones, dp_scratch);
- vpaddd(dst, dst, dp_scratch);
- }
-}
-
-// Inner kernel.
-void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n,
- bool cfetch)
-{
- int um_vecs = (unroll_m + 15) >> 4;
- Label label_kernel_loop;
-
- L_aligned(label_kernel_loop); {
- for (int h = 0; h < 4; h++) {
- for (int j = 0; j < unroll_n; j++) {
- const Zmm b = b_regs[j & 1];
-
- vpbroadcastd(b, ptr[BO + isize *
- (2 * j + 2 * h * unroll_n - offset_b)]);
- dot_product(c_regs[0][j], b, a_regs[0]);
-
- if (j == 1 && !(h & 1))
- prefetch_b(ptr[BO + isize * (prefetch_size_b
- + 2 * h * unroll_n - offset_b)]);
- else if (j % 3 == 0)
- prefetch_a(ptr[AO + isize * (prefetch_size_a
- + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]);
-
- for (int i = 1; i < um_vecs; i++)
- dot_product(c_regs[i][j], b, a_regs[i]);
-
- if (cfetch && (j == std::min(1, unroll_n - 1))) {
- if (h == 3)
- lea(CO2, ptr[CO2 + LDC]);
- else if (h < um_vecs)
- prefetch_c(ptr[CO2 + (16 * h * size)]);
- }
-
- if (h == 3 && j == std::min(3, unroll_n - 1))
- lea(AA, ptr[AA + (32 * isize)]);
- }
-
- for (int i = 0; i < um_vecs; i++)
- vmovups(a_regs[i], ptr[AO + isize *
- (32 * i + 2 * (h + 1) * unroll_m - offset_a)]);
-
- if (h == 2)
- prefetch_x(ptr[AA - (offset_a * isize)]);
- }
-
- add(AO, 8 * isize * unroll_m);
- add(BO, 8 * isize * unroll_n);
- sub(LoopCount, 1);
- jg(label_kernel_loop, T_NEAR);
- }
-}
-
-// k remainder loop for kernel.
-void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m,
- int unroll_n, int unroll_k, int bwidth)
-{
- if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
- || (unroll_m < 0) || (unroll_n < 0))
- return;
-
- int um_vecs = (unroll_m + 15) >> 4;
-
- for (int h = 0; h < unroll_k; h++) {
- for (int j = 0; j < unroll_n; j++) {
- Zmm b = b_regs[j & 1];
- auto b_src = ptr[BO + (-isize * offset_b
- + bwidth * (j + h * unroll_n))];
-
- switch (bwidth) {
- case 4:
- vpbroadcastd(b, b_src);
- break;
- case 2:
- vpbroadcastw(b, b_src);
- break;
- case 1:
- vpbroadcastb(b, b_src);
- break;
- }
- for (int i = 0; i < um_vecs; i++)
- dot_product(c_regs[i][j], b, a_regs[i]);
- }
-
- if (unroll_k > 1) {
- for (int i = 0; i < um_vecs; i++)
- vmovups(a_regs[i], ptr[AO + isize * (32 * i
- + (h + 1) * 2 * unroll_m - offset_a)]);
- }
- }
-
- add(AO, unroll_k * unroll_m * bwidth);
- add(BO, unroll_k * unroll_n * bwidth);
-}
-
-// Inner loop.
-void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n)
-{
- if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
- || (unroll_m < 0) || (unroll_n < 0))
- return;
-
- int um_vecs = (unroll_m + 15) >> 4;
- int stage1 = unroll_n, stage2 = unroll_n;
-
- Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2;
- Label label_k_main_loop_3, label_kernel_loop_3;
- Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2;
- Label label_k_rem_1, label_update_begin;
-
- mov(AO, A);
- for (int i = 0; i < um_vecs; i++)
- vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]);
-
- mov(LoopCount, K);
- sar(LoopCount, 4);
- jle(label_k_remainder_loop_begin, T_NEAR);
-
- // Main k loops, broken into three parts to time C prefetching.
- sub(LoopCount, stage1 + stage2);
- jle(label_k_main_loop_2, T_NEAR);
-
- kernel_loop(unroll_m, unroll_n, false);
-
- L_aligned(label_k_main_loop_2);
- lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
- add(LoopCount, stage1);
- jle(label_k_main_loop_3, T_NEAR);
-
- kernel_loop(unroll_m, unroll_n, true);
-
- L_aligned(label_k_main_loop_3);
- lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
- add(LoopCount, stage2);
- jle(label_k_remainder_loop_begin, T_NEAR);
-
- kernel_loop(unroll_m, unroll_n, true);
-
- // k remainder handling
- L_aligned(label_k_remainder_loop_begin);
- mov(LoopCount, K);
- test(LoopCount, 8);
- je(label_k_rem_4, T_NEAR);
-
- remainder_kernel(unroll_m, unroll_n, 2, 4);
-
- L_aligned(label_k_rem_4);
- mov(LoopCount, K);
- test(LoopCount, 4);
- je(label_k_rem_2, T_NEAR);
-
- remainder_kernel(unroll_m, unroll_n, 1, 4);
-
- L_aligned(label_k_rem_2);
- mov(LoopCount, K);
- test(LoopCount, 2);
- je(label_k_rem_1, T_NEAR);
-
- Zmm zero = zmm6;
- Zmm tmp = zmm5;
-
- vpxorq(zero, zero, zero);
- for (int i = 0; i < um_vecs; i++) {
- Zmm a = a_regs[i];
- vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]);
- vpunpcklwd(tmp, a, zero);
- vpunpckhwd(a, a, zero);
- vshufi32x4(a, tmp, a, 0x44);
- vshufi32x4(a, a, a, 0xD8);
- }
-
- remainder_kernel(unroll_m, unroll_n, 1, 2);
-
- L_aligned(label_k_rem_1);
- mov(LoopCount, K);
- test(LoopCount, 1);
- je(label_update_begin, T_NEAR);
-
- vpxorq(zero, zero, zero);
- for (int i = 0; i < um_vecs; i++) {
- Zmm a = a_regs[i];
- vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]);
- vpunpcklbw(tmp, a, zero);
- vpunpckhbw(a, a, zero);
- vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1);
- vpunpcklwd(tmp, a, zero);
- vpunpckhwd(a, a, zero);
- vshufi32x4(a, tmp, a, 0x44);
- vshufi32x4(a, a, a, 0xD8);
- }
-
- remainder_kernel(unroll_m, unroll_n, 1, 1);
-
- // Add offsets and update C.
- L_aligned(label_update_begin);
-
- if (enable_offset_r) {
- // Add row offsets.
- mov(rax, coffset_ry);
- for (int j = 0; j < unroll_n; j++) {
- Zmm row_offset = zmm0;
-
- vbroadcastss(row_offset, ptr[rax + size * j]);
-
- for (int i = 0; i < um_vecs; i++)
- vpaddd(c_regs[i][j], c_regs[i][j], row_offset);
- }
- add(coffset_ry, size * unroll_n);
- }
-
- if (enable_offset_c) {
- // Add column offsets.
- mov(rax, coffset_cy);
- for (int i = 0; i < um_vecs; i++) {
- Zmm col_offset = zmm0;
-
- c_load(col_offset, ptr[rax + size * 16 * i], unroll_m);
-
- for (int j = 0; j < unroll_n; j++)
- vpaddd(c_regs[i][j], c_regs[i][j], col_offset);
- }
- }
-
- Reg64 LDC3 = rax;
- lea(LDC3, ptr[LDC + LDC * 2]);
-
- // C updates.
- int c_off_j = 0;
- for (int j = 0; j < unroll_n; j++) {
- if (j > 0 && (j & 3) == 0) {
- lea(CO1, ptr[CO1 + LDC * 4]);
- c_off_j += 4;
- }
-
- int jj = j - c_off_j;
-
- for (int i = 0; i < um_vecs; i++) {
- Zmm c = c_regs[i][j];
- Zmm c_old = zmm0;
- decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj;
-
- auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i];
-
- if (beta_zero)
- c_store(c_mem, c, unroll_m);
- else {
- c_load(c_old, c_mem, unroll_m);
- vpaddd(c_old, c, c_old);
- c_store(c_mem, c_old, unroll_m);
- }
-
- vpxorq(c, c, c);
- }
- }
-
- lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]);
-}
-
-// Outer loop.
-void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y,
- Label *&cur_outerloop_label)
-{
- Label label_m_loop, label_n_loop, label_n_remainder_loops[6];
-
- L(*cur_outerloop_label);
- cur_outerloop_label++;
- if (unroll_x >= IGEMM_UNROLL_M) {
- mov(J, M);
- cmp(J, unroll_x);
- jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
- } else {
- test(J, unroll_x);
- jle(*cur_outerloop_label, T_NEAR);
- }
-
- L_aligned(label_m_loop); {
- mov(CO1, C);
- add(C, unroll_x * size);
-
- mov(BO, B);
-
- mov(AA, K);
- imul(AA, AA, unroll_x * isize);
- lea(AA, ptr[A + AA + isize * prefetch_size_a]);
-
- if (enable_offset_c) {
- mov(rax, coffset_cx);
- mov(coffset_cy, rax);
- add(rax, unroll_x * size);
- mov(coffset_cx, rax);
- }
-
- if (enable_offset_r) {
- mov(rax, coffset_rx);
- mov(coffset_ry, rax);
- }
-
- mov(I, N);
- cmp(I, unroll_y);
- jl(label_n_remainder_loops[0], T_NEAR);
-
- L_aligned(label_n_loop); {
- innerloop(unroll_x, unroll_y);
- sub(I, unroll_y);
- cmp(I, unroll_y);
- jge(label_n_loop, T_NEAR);
- }
-
- align(16);
-
- int label_idx = 0;
- for (int uy = 16; uy > 0; uy >>= 1) {
- L(label_n_remainder_loops[label_idx++]);
- if (unroll_y > uy) {
- test(I, uy);
- jle(label_n_remainder_loops[label_idx], T_NEAR);
-
- innerloop(unroll_x, uy);
- align(16);
- }
- }
- L(label_n_remainder_loops[label_idx]);
-
- mov(A, AO);
- if (unroll_x >= IGEMM_UNROLL_M) {
- sub(J, unroll_x);
- cmp(J, unroll_x);
- jge(label_m_loop);
- }
- }
-
- align(16);
-}
-
-void jit_avx512_core_gemm_s8u8s32_kern::generate()
-{
- // Prologue
- preamble();
- sub(rsp, stack_alloc_size);
-
- if (is_windows) {
- mov(A, arg_a);
- mov(B, arg_b);
- }
-
- mov(C, arg_c);
- mov(LDC, arg_ldc);
-
- sub(A, -offset_a * isize);
- sub(B, -offset_b * isize);
-
- mov(M, qword[M]);
- mov(N, qword[N]);
- mov(K, qword[K]);
-
- lea(LDC, ptr[LDC * size]);
-
- if (enable_offset_c) {
- mov(rax, arg_coffset_c);
- mov(coffset_cx, rax);
- }
- if (enable_offset_r) {
- mov(rax, arg_coffset_r);
- mov(coffset_rx, rax);
- }
-
- for (int i = 0; i < (max_unroll_m >> 4); i++) {
- for (int j = 0; j < max_unroll_n; j++) {
- auto &c = c_regs[i][j];
- vpxorq(c, c, c);
- }
- }
-
- if (!vnni) {
- mov(rax, 1);
- movq(make_xmm(ones), rax);
- vpbroadcastw(ones, make_xmm(ones));
- }
-
- Label outerloop_labels[8];
- Label *cur_outerloop_label = &outerloop_labels[0];
-
- // Main m loop.
- outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label);
-
- // m remainder loops.
- for (int um = 32; um > 0; um >>= 1)
- if (IGEMM_UNROLL_M > um)
- outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label);
-
- L(*cur_outerloop_label);
-
- // Epilogue.
- add(rsp, stack_alloc_size);
- postamble();
-}
-
-
-jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool
- beta_zero_, bool enable_offset_c_, bool enable_offset_r_) :
- jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0),
- arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0),
- coffset_rx(0), coffset_ry(0)
-{
- beta_zero = beta_zero_;
- enable_offset_c = enable_offset_c_;
- enable_offset_r = enable_offset_r_;
- vnni = mayiuse(avx512_core_vnni);
-
- // Assign integer registers
- M = is_windows ? rcx : rdi;
- N = is_windows ? rdx : rsi;
- K = is_windows ? r8 : rdx;
- A = is_windows ? rsi : r8;
- B = r9;
- C = r10;
- LDC = r11;
- I = r12;
- J = r13;
- LoopCount = rax;
- AO = r14;
- BO = r15;
- CO1 = rbx;
- CO2 = rbp;
- AA = is_windows ? rdi : rcx;
-
- // Assign vector registers
- dp_scratch = zmm6;
- ones = zmm7;
- for (int i = 0; i < (max_unroll_m >> 4); i++)
- a_regs[i] = Zmm(i);
- b_regs[0] = zmm4;
- b_regs[1] = zmm5;
-
- int rn = 0;
- for (int i = 0; i < (max_unroll_m >> 4); i++)
- for (int j = 0; j < max_unroll_n; j++)
- c_regs[i][j] = Zmm(8 + rn++);
-
- // Assign stack variables.
- stack_alloc_size = 32;
- auto args_offset = stack_alloc_size + get_size_of_abi_save_regs()
- + 8 + (is_windows ? 48 : 0);
-
- arg_a = ptr[rsp + (args_offset - 16)];
- arg_b = ptr[rsp + (args_offset - 8)];
- arg_c = ptr[rsp + (args_offset + 0)];
- arg_ldc = ptr[rsp + (args_offset + 8)];
- arg_coffset_c = ptr[rsp + (args_offset + 16)];
- arg_coffset_r = ptr[rsp + (args_offset + 24)];
-
- coffset_cx = qword[rsp + 0];
- coffset_cy = qword[rsp + 8];
- coffset_rx = qword[rsp + 16];
- coffset_ry = qword[rsp + 24];
-
- generate();
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
deleted file mode 100644
index e8efcc1cc8..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
+++ /dev/null
@@ -1,101 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef IGEMM_KERNEL_GENERATOR_HPP
-#define IGEMM_KERNEL_GENERATOR_HPP
-
-#include "jit_generator.hpp"
-
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator {
-public:
- jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_,
- bool enable_offset_r_);
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern);
-
-protected:
- bool beta_zero;
- bool enable_offset_c, enable_offset_r;
- bool vnni;
-
- void prefetch_a(const Xbyak::Address &src) {
- prefetcht0(src);
- }
- void prefetch_b(const Xbyak::Address &src) {
- prefetcht0(src);
- }
- void prefetch_c(const Xbyak::Address &src) {
- prefetchw(src);
- }
- void prefetch_x(const Xbyak::Address &src) {
- prefetcht0(src);
- }
-
- void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems);
- void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems);
-
- void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1,
- const Xbyak::Xmm &src2);
- void kernel_loop(int unroll_m, int unroll_n, bool cfetch);
- void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth);
- void innerloop(int unroll_m, int unroll_n);
- void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label);
-
- void generate();
-
-
-private:
- static const int IGEMM_UNROLL_M = 48;
- static const int IGEMM_UNROLL_N = 8;
-
- static const int isize = 2;
- static const int size = 4;
-
- // Prefetch configuration
- static const int prefetch_size_a = 32 * 5;
- static const int prefetch_size_b = 32 * 4;
-
- static const int offset_a = 256, offset_b = 256;
- static const int max_unroll_m = 48, max_unroll_n = 8;
-
- // Integer register assignments
- Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount;
- Xbyak::Reg64 AO, BO, CO1, CO2, AA;
-
- // Vector register assignments
- Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2];
- Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n];
-
- // Stack variable assignments
- int stack_alloc_size;
- Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r;
- Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry;
-
- void L_aligned(Xbyak::Label &label, int alignment = 16) {
- align(alignment);
- L(label);
- }
-};
-
-}
-}
-}
-
-#endif /* header guard */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
deleted file mode 100644
index 4f0b10dadd..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
+++ /dev/null
@@ -1,290 +0,0 @@
-/*******************************************************************************
- * Copyright 2019 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include "gemv.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) {
-
- blas_t arg_gemv = *arg;
-
- if ((arg -> offsetc == FIX_OFFSET) && // Fix offset
- (arg -> ao == 0) &&
- (arg -> bo == 0) &&
- (arg -> co[0] == 0) &&
- (*(arg -> alpha) == 1.0f) &&
- ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) {
-
- if (arg -> n == 1) {
-
- if (arg -> transa == 1) { // A transpose
- arg_gemv.n = arg -> k;
- arg_gemv.ldc = 1;
- arg_gemv.swap = 0;
- if (arg -> transb == 0) { // B non transpose
- arg_gemv.ldb = 1;
- }
- // B transpose arg_gemv.ldb = arg -> ldb
- gemv_threading_driver(&arg_gemv);
- return 1;
- }
- }
-
- if (arg -> m == 1) {
-
- if (arg -> transb == 0) { // B non transpose
- arg_gemv.transa = 1;
- arg_gemv.m = arg -> n;
- arg_gemv.n = arg -> k;
- arg_gemv.a = (int8_t *) arg -> b;
- arg_gemv.lda = arg -> ldb;
- arg_gemv.b = (uint8_t *) arg -> a;
- arg_gemv.swap = 1;
- if (arg -> transa == 0) { // A non transpose
- arg_gemv.ldb = arg -> lda;
- }
- else { // A transpose
- arg_gemv.ldb = 1;
- }
- gemv_threading_driver(&arg_gemv);
- return 1;
- }
- }
- }
-
- return 0;
-}
-
-
-int gemv_kernel_driver(blas_t *arg) {
-
- dim_t m = arg -> m;
- dim_t n = arg -> n;
- uint8_t *a = (uint8_t *) arg -> a;
- dim_t lda = arg -> lda;
- int8_t *b = (int8_t *) arg -> b;
- float beta = *(arg -> beta);
-
- if (arg -> swap) {
- arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c);
- }
- else {
- arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a,
- arg -> lda, arg -> b, *(arg -> beta), arg -> c);
- }
-
- return 0;
-}
-
-int gemv_threading_driver(blas_t *arg) {
-
- dim_t nthr_m, nthr_n = 1;
- dim_t MB, NB, UM = 16, UN = 64;
- dim_t BLOCKM = 192, BLOCKN = 3072;
- int status;
- dim_t i;
-
- dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
-
- uint8_t *new_x = NULL;
- int32_t *tmp_y = NULL, *new_y = NULL;
-
- dim_t m = arg -> m, n = arg -> n;
-
- blas_t arg_seq = *arg;
- float zero = 0.0f;
-
- nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr);
- MB = m / nthr_m;
- MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM;
- nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1;
- nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr);
-
- while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) {
- nthr_n++;
- }
-
- NB = n / nthr_n;
- NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN;
- nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1;
- nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m);
-
- nthr = nthr_m * nthr_n;
-
- if (arg -> ldb != 1) {
- new_x = (uint8_t *)malloc(n, 64);
- if (new_x == NULL)
- return 1;
- for (i = 0; i < n; i++) {
- new_x[i] = (arg -> b)[i * arg -> ldb];
- }
- arg_seq.b = new_x;
- arg_seq.ldb = 1;
- }
- else new_x = (uint8_t *) arg -> b;
-
- if (arg -> ldc != 1) {
- new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64);
- if (new_y == NULL) {
- if (arg -> ldb != 1) {
- free(new_x);
- }
- return 1;
- }
- }
-
- // GEMV computation
- if (nthr == 1) {
-
- if (arg -> ldc != 1) {
- if (*(arg -> beta) != 0.0f) {
- for (i = 0; i < m; i++) {
- new_y[i] = arg -> c[i * arg -> ldc];
- }
- }
- }
-
- status = gemv_kernel_driver(&arg_seq);
-
- if (arg -> ldc != 1) {
- for (i = 0; i < m; i++) {
- arg -> c[i * arg -> ldc] = new_y[i];
- }
- }
-
- if (arg -> ldb != 1) {
- free(new_x);
- }
- if (arg -> ldc != 1) {
- free(new_y);
- }
- return status;
- }
-
- if (nthr_n > 1) {
- tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE);
- if (tmp_y == NULL) {
- if (arg -> ldb != 1) {
- free(new_x);
- }
- return 1;
- }
- }
-
- parallel_nd((int) nthr, [&](const dim_t ithr) {
-
- dim_t m_from, m_to, myM;
- dim_t n_from, n_to, myN;
-
- dim_t n_id, m_id;
- dim_t loc_incy = 1;
- int32_t *loc_y;
-
- blas_t arg_loc = arg_seq;
- int j;
-
- m_id = ithr / nthr_n;
- n_id = ithr % nthr_n;
-
- m_from = MB * m_id;
- m_to = MB * (m_id + 1);
- if ((m_to > m) || (m_id == nthr_m - 1))
- m_to = m;
-
- myM = m_to - m_from;
-
- n_from = NB * n_id;
- n_to = NB * (n_id + 1);
- if ((n_to > n) || (n_id == nthr_n - 1))
- n_to = n;
-
- myN = n_to - n_from;
-
- if (n_id != 0) {
- arg_loc.beta = &zero;
- loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from;
- }
- else {
- if (arg -> ldc == 1) {
- loc_y = arg_seq.c + m_from;
- }
- else {
- // need to copy the block of c in new_y
- loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t));
- if (*(arg -> beta) != 0.0f) {
- for (j = 0; j < myM; j++) {
- loc_y[j] = arg -> c[(m_from + j) * arg -> ldc];
- }
- }
- }
- }
-
- arg_loc.m = myM;
- arg_loc.n = myN;
- arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from;
- arg_loc.b = arg_seq.b + n_from;
- arg_loc.c = loc_y;
- arg_loc.ldc = loc_incy;
-
- gemv_kernel_driver(&arg_loc);
-
- if ((n_id == 0) && (arg -> ldc != 1)) {
- for (j = 0; j < myM; j++) {
- arg -> c[(m_from + j) * arg -> ldc] = loc_y[j];
- }
- }
-
- });
-
- if (nthr_n > 1) {
- parallel_nd((int) nthr_m, [&](const dim_t ithr) {
-
- dim_t j, j_from, j_to, ii;
- int32_t acc;
-
- j_from = MB * ithr;
- j_to = MB * (ithr + 1);
- if ((j_to > m) || (ithr == nthr - 1))
- j_to = m;
-
- for (j = j_from; j < j_to; j++) {
- acc = 0;
- for (ii = 0; ii < nthr_n - 1; ii++) {
- acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j];
- }
- (arg -> c)[j * arg -> ldc] += acc;
- }
- });
- free(tmp_y);
- }
-
- if (arg -> ldb != 1) {
- free(new_x);
- }
-
- if (arg -> ldc != 1) {
- free(new_y);
- }
-
- return 0;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
deleted file mode 100644
index c57a8c1d12..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
+++ /dev/null
@@ -1,411 +0,0 @@
-/*******************************************************************************
- * Copyright 2019 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
-
-#ifdef _WIN32
-#define is_windows 1
-#else
-#define is_windows 0
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b,
- Xbyak::Zmm a, Xbyak::Zmm tmp,
- Xbyak::Zmm one, bool swap,
- int use_vnni) {
-
- if (use_vnni) {
- if (swap)
- vpdpbusd(acc, a, b);
- else
- vpdpbusd(acc, b, a);
- }
-
- else {
- if (swap)
- vpmaddubsw(tmp, a, b);
- else
- vpmaddubsw(tmp, b, a);
- vpmaddwd(tmp, tmp, one);
- vpaddd(acc, tmp, acc);
- }
-
-}
-
-void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx,
- int b_idx, int nreg_acc,
- Xbyak::Reg64 A, Xbyak::Reg64 lda,
- Xbyak::Reg64 X, Xbyak::Zmm tmp,
- Xbyak::Zmm one, bool swap, int use_vnni,
- int use_mask, Xbyak::Opmask mask_n) {
-
- int i;
- int nreg_A = nreg_acc / 2 + (nreg_acc % 2);
-
- // load X + j
- if (use_mask)
- vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]);
- else
- vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]);
-
- xor_(r14, r14);
- // load values of A
- for (i = 0; i < nreg_A; i++) {
- if (use_mask)
- vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
- else
- vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
- add(r14, lda);
- }
-
- for (i = 0; i < nreg_A; i++) {
- // vnni (acc, b, a, tmp, one, swap, use_vnni)
- vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx),
- Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
- }
-
- for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
- if (use_mask)
- vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
- else
- vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
- add(r14, lda);
- }
-
- for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
- vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx),
- Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
- }
-
-}
-
-void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A,
- Xbyak::Zmm B, Xbyak::Zmm C,
- Xbyak::Zmm D) {
-
- vshufi32x4(dest, A, C, 0x44);
- vshufi32x4(A, A, C, 0xEE);
- vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3
-
- vshufi32x4(dest, B, D, 0x44);
- vshufi32x4(B, B, D, 0xEE);
- vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3
-
- vshufi32x4(A, C, D, 0x88);
- vshufi32x4(B, C, D, 0xDD);
- vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi
-
-}
-
-void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y,
- int start_a_idx, int start_acc_idx,
- Xbyak::Xmm beta, int use_mask,
- Xbyak::Opmask mask_m) {
-
- int l, i, k, j, last_it;
- Xbyak::Label store_label;
-
- l = 0;
- for (k = 0; k < nreg_acc; k += 8) {
- for (i = 0, j = k; i < 8; i += 4, j += 2) {
- if (j < nreg_acc) {
- // shuffle per block of 4 registers
- shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest
- Xbyak::Zmm(start_acc_idx + j), // A = acc0
- Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1
- Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4
- Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5
-
- // extract low and high from dest and hadd
- vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0);
- vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1);
- vphaddd(Xbyak::Ymm(start_a_idx + l),
- Xbyak::Ymm(start_a_idx + l + 1),
- Xbyak::Ymm(start_a_idx + l + 2));
- }
- l++;
- }
-
- vphaddd(Xbyak::Ymm(start_a_idx + l),
- Xbyak::Ymm(start_a_idx + l - 2),
- Xbyak::Ymm(start_a_idx + l - 1));
-
- l++;
- }
-
- // eventually add with C and store new value
- vxorps(Xbyak::Ymm(start_a_idx),
- Xbyak::Ymm(start_a_idx),
- Xbyak::Ymm(start_a_idx));
- vucomiss(beta, Xbyak::Ymm(start_a_idx));
- je(store_label, T_NEAR);
-
- // beta = 1
- for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
- // load Y and add
- last_it = (k + 8) > nreg_acc;
- if (use_mask && last_it)
- vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]);
- else
- vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]);
-
- vpaddd(Xbyak::Ymm(start_a_idx + l),
- Xbyak::Ymm(start_a_idx + l),
- Xbyak::Ymm(start_a_idx + k / 8));
- }
-
- // store
- aligned_label(store_label);
- for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
- last_it = (k + 8) > nreg_acc;
- if (use_mask && last_it)
- vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m);
- else
- vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l));
- }
-
-}
-
-template <typename T>
-T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) {
-
- Xbyak::Opmask mask_n = k1, mask_m = k2;
- Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label;
- Xbyak::Label n_tail_label, update_c_label, end_label;
- constexpr unsigned int n_labels = (1 << unroll_m) - 1;
- Xbyak::Label m_tail_label_case[n_labels];
- Xbyak::Label n_loop_label_case[n_labels];
- Xbyak::Label n_tail_label_case[n_labels];
- Xbyak::Label update_c_label_case[n_labels];
-
- int i, ii;
-
- Xbyak::Zmm one, tmp;
- Xbyak::Reg64 n = abi_param2, m = abi_param1;
- Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3;
- Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4;
- Xbyak::Reg64 X = is_windows ? rdi : r8;
- Xbyak::Xmm beta = xmm1;
- Xbyak::Reg64 Y = is_windows ? rsi : r9;
-
- bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value;
-
- // Windows: read on the stack lda, X, beta, Y
-
- int zmm_idx = 1;
- int nreg_acc = 1 << unroll_m;
- int nreg_A = 1 << (unroll_m - 1);
- int nreg_A_acc = nreg_acc + nreg_A;
-
- if (!use_vnni) {
- // set a zmm register to one
- tmp = Xbyak::Zmm(0);
- one = Xbyak::Zmm(zmm_idx + 1);
- zmm_idx += 2; // one + tmp
- }
- else {
- beta = xmm0;
- }
-
- preamble();
-
- if (is_windows) {
- mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]);
- mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]);
- movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]);
- mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]);
- }
-
- if (use_vnni && !is_windows) {
- movaps(beta, xmm1);
- }
-
- mov(rax, (1 << unroll_n) - 1);
- kmovq(k3, rax);
-
- and_(rax, n); // rax contains n & ((1 << unroll_n) - 1)
- mov(rbx, 1);
- shlx(rbx, rbx, rax);
- sub(rbx, 1);
- kmovq(mask_n, rbx);
- // mask_n set (AVX512 only), can use rax and rbx again
-
- // set mask_m for update of the C matrix
- // load/store on the C matrix use Ymm so tail according to Ymm size
- mov(rax, 7); // 8 * 32 = 256 Ymm size
- and_(rax, m); // rax contains m & 7
- mov(rbx, 1);
- shlx(rbx, rbx, rax);
- sub(rbx, 1);
- kmovq(mask_m, rbx);
- // mask_m set (AVX512 only), can use rax and rbx again
-
- // setup register of ones when VNNI instructions not available
- if (!use_vnni) {
- vmovdqu16(one, ptr[rip + one_label]);
- }
-
- // M loop
- // base pointer for A rax contains a + i * lda
- // Loop stop when rax >= a + (m & mask_um) * lda = rbx
- // loop increment r10 = um * lda
- // rbp = Y + i
- mov(rax, A); // i = 0
- mov(rbx, m);
- and_(rbx, mask_um);
- imul(rbx, lda);
- add(rbx, A);
- mov(r10, lda);
- sal(r10, unroll_m);
- mov(rbp, Y);
-
- // N loop
- // base pointer for X r11 contains x + j
- // Loop stop when r11 >= x + n & mask_un = r12
- // loop increment un
- // r13 = rax + j = A + i * lda + j
- mov(r12, n);
- and_(r12, mask_un);
- add(r12, X);
-
- // M loop
- aligned_label(m_loop_label);
- cmp(rax, rbx);
- jge(m_tail_label, T_NEAR);
-
- // enter M loop
- for(i = 0; i < nreg_acc; i++) {
- vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
- Xbyak::Zmm(i + zmm_idx + nreg_A),
- Xbyak::Zmm(i + zmm_idx + nreg_A));
- }
-
- // N loop
- mov(r11, X); // j = 0
- mov(r13, rax);
- aligned_label(n_loop_label);
- cmp(r11, r12);
- jge(n_tail_label, T_NEAR);
-
- // enter N loop
-
- n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
- r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
-
- // increment rax with un
- add(r11, 1 << unroll_n);
- add(r13, 1 << unroll_n);
- jmp(n_loop_label, T_NEAR);
- // end N loop
-
- // N tail
- aligned_label(n_tail_label);
-
- ktestq(mask_n, k3);
- je(update_c_label, T_NEAR);
- n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
- r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
-
- // update C matrix
- aligned_label(update_c_label);
-
- update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m);
-
- // increment rax with um * lda
- add(rax, r10);
- add(rbp, 1 << (unroll_m + 2));
- jmp(m_loop_label, T_NEAR);
- // end M loop
-
- // M tail
- aligned_label(m_tail_label);
-
- // r10 will contain m_tail = m % unroll_m = m & (1 << unroll_m) - 1
- mov(r10, m);
- and_(r10, (1 << unroll_m) - 1);
- for (ii = 1; ii < 1 << unroll_m; ii++) {
- aligned_label(m_tail_label_case[ii-1]);
- cmp(r10, ii);
- if (ii == (1 << unroll_m) - 1)
- jne(end_label, T_NEAR);
- else
- jne(m_tail_label_case[ii], T_NEAR);
-
- // m_tail = i, use i accumulators
-
- for(i = 0; i < ii; i++) {
- vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
- Xbyak::Zmm(i + zmm_idx + nreg_A),
- Xbyak::Zmm(i + zmm_idx + nreg_A));
- }
-
- // N loop
- mov(r11, X); // j = 0
- mov(r13, rax);
- aligned_label(n_loop_label_case[ii - 1]);
- cmp(r11, r12);
- jge(n_tail_label_case[ii - 1], T_NEAR);
-
- n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
- lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
-
- // increment rax with un
- add(r11, 1 << unroll_n);
- add(r13, 1 << unroll_n);
- jmp(n_loop_label_case[ii - 1], T_NEAR);
- // end N loop
-
- // N tail
- aligned_label(n_tail_label_case[ii - 1]);
- ktestq(mask_n, k3);
- je(update_c_label_case[ii - 1], T_NEAR);
- n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
- lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
-
- // update C matrix
- aligned_label(update_c_label_case[ii - 1]);
- update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m);
-
- if (ii < ((1 << unroll_m) - 1))
- jmp(end_label, T_NEAR);
- }
-
- aligned_label(end_label);
-
- postamble();
-
- if (!use_vnni) {
- aligned_label(one_label);
- for (i = 0; i < size_vec_reg/8; i++)
- dq(0x0001000100010001);
- }
-
- return (T) getCode();
-}
-
-template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t
-jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int);
-
-template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t
-jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int);
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
deleted file mode 100644
index 9ea23a5f56..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
+++ /dev/null
@@ -1,64 +0,0 @@
-/*******************************************************************************
- * Copyright 2019 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
-
-    // assumes unroll_{m,n} are a power of 2
- static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
- const int mask_um = 0xFFFFFFF0;
- static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
- const int mask_un = 0xFFFFFFC0;
- const int size_vec_reg = 64; // bytes
-
- void aligned_label(Xbyak::Label &label, int alignment = 16) {
- align(alignment);
- L(label);
- }
-
- void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
- void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
- Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
- void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
- void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
-
-public:
- jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {};
-
- // m, n, alpha, a, lda, x, beta, y
- typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
- const int8_t*, const dim_t, const uint8_t*,
- const float, int32_t*);
- typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
- const uint8_t*, const dim_t, const int8_t*,
- const float, int32_t*);
-
- template <typename T>
- T generate(int use_vnni);
-
-};
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
deleted file mode 100644
index 544cd2ff25..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
+++ /dev/null
@@ -1,819 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l170;
-Xbyak::Label l1f0;
-Xbyak::Label l20;
-Xbyak::Label l224;
-Xbyak::Label l234;
-Xbyak::Label l240;
-Xbyak::Label l254;
-Xbyak::Label l32c;
-Xbyak::Label l34;
-Xbyak::Label l388;
-Xbyak::Label l3b0;
-Xbyak::Label l3c0;
-Xbyak::Label l3cc;
-Xbyak::Label l3dc;
-Xbyak::Label l454;
-Xbyak::Label l48c;
-Xbyak::Label l4a8;
-Xbyak::Label l4b8;
-Xbyak::Label l4c4;
-Xbyak::Label l4d8;
-Xbyak::Label l570;
-Xbyak::Label l5c4;
-Xbyak::Label l5f0;
-Xbyak::Label l60c;
-Xbyak::Label l61c;
-Xbyak::Label l628;
-Xbyak::Label l638;
-Xbyak::Label l6b0;
-Xbyak::Label l6f4;
-Xbyak::Label l720;
-Xbyak::Label l73c;
-Xbyak::Label l74c;
-Xbyak::Label l758;
-Xbyak::Label l76c;
-Xbyak::Label l804;
-Xbyak::Label l858;
-Xbyak::Label l88c;
-Xbyak::Label l8a4;
-Xbyak::Label l8b2;
-Xbyak::Label l8bc;
-Xbyak::Label l8cc;
-Xbyak::Label l944;
-Xbyak::Label l98c;
-Xbyak::Label l9b0;
-Xbyak::Label l9c8;
-Xbyak::Label l9d8;
-
- preamble();
-#ifdef _WIN32
- auto stacksize = get_size_of_abi_save_regs();
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(M, qword[M]);
- mov(N, qword[N]);
- mov(LDA, qword[LDA]);
- lea(LDA3, ptr[LDA+LDA*2]);
- sub(A, -128);
- sub(B, -128);
- cmp(N, 0x30);
- jl(l234, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- add(A, 0x30);
- mov(I, M);
- sar(I, 0x2);
- jle(l170, T_NEAR);
- align(4);
-
-L(l34);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm2);
- movdqu(xmm0, xword[A1-0x70]);
- movdqu(xmm1, xword[A1+LDA*1-0x70]);
- movdqu(xmm2, xword[A1+LDA*2-0x70]);
- movdqu(xmm3, xword[A1+LDA3*1-0x70]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B-0x30], xmm1);
- movdqu(xword[B-0x20], xmm4);
- movdqu(xword[B-0x10], xmm2);
- movdqu(xmm0, xword[A1-0x60]);
- movdqu(xmm1, xword[A1+LDA*1-0x60]);
- movdqu(xmm2, xword[A1+LDA*2-0x60]);
- movdqu(xmm3, xword[A1+LDA3*1-0x60]);
- lea(A1, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- movdqu(xword[B], xmm0);
- movdqu(xword[B+0x10], xmm1);
- movdqu(xword[B+0x20], xmm4);
- movdqu(xword[B+0x30], xmm2);
- sub(B, -192);
- dec(I);
- jg(l34, T_NEAR);
- align(4);
-
-L(l170);
- test(M, 0x2);
- jle(l1f0, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- movdqu(xmm2, xword[A1-0x60]);
- add(A1, LDA);
- movdqu(xmm3, xword[A1-0x80]);
- movdqu(xmm4, xword[A1-0x70]);
- movdqu(xmm5, xword[A1-0x60]);
- add(A1, LDA);
- movdqa(xmm6, xmm0);
- punpcklbw(xmm0, xmm3);
- punpckhbw(xmm6, xmm3);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm6);
- movdqa(xmm6, xmm1);
- punpcklbw(xmm1, xmm4);
- punpckhbw(xmm6, xmm4);
- movdqu(xword[B-0x60], xmm1);
- movdqu(xword[B-0x50], xmm6);
- movdqa(xmm6, xmm2);
- punpcklbw(xmm2, xmm5);
- punpckhbw(xmm6, xmm5);
- movdqu(xword[B-0x40], xmm2);
- movdqu(xword[B-0x30], xmm6);
- sub(B, -96);
- align(4);
-
-L(l1f0);
- test(M, 0x1);
- jle(l224, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- movdqu(xmm2, xword[A1-0x60]);
- add(A1, LDA);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movdqu(xword[B-0x60], xmm2);
- sub(B, -48);
- align(4);
-
-L(l224);
- sub(N, 0x30);
- cmp(N, 0x30);
- jge(l20, T_NEAR);
- align(4);
-
-L(l234);
- cmp(N, 0x20);
- jl(l3c0, T_NEAR);
- align(4);
-
-L(l240);
- mov(A1, A);
- add(A, 0x20);
- mov(I, M);
- sar(I, 0x2);
- jle(l32c, T_NEAR);
- align(4);
-
-L(l254);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm2);
- movdqu(xmm0, xword[A1-0x70]);
- movdqu(xmm1, xword[A1+LDA*1-0x70]);
- movdqu(xmm2, xword[A1+LDA*2-0x70]);
- movdqu(xmm3, xword[A1+LDA3*1-0x70]);
- lea(A1, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B-0x30], xmm1);
- movdqu(xword[B-0x20], xmm4);
- movdqu(xword[B-0x10], xmm2);
- sub(B, -128);
- dec(I);
- jg(l254, T_NEAR);
- align(4);
-
-L(l32c);
- test(M, 0x2);
- jle(l388, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- add(A1, LDA);
- movdqu(xmm2, xword[A1-0x80]);
- movdqu(xmm3, xword[A1-0x70]);
- add(A1, LDA);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm2);
- punpckhbw(xmm4, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm4);
- movdqa(xmm4, xmm1);
- punpcklbw(xmm1, xmm3);
- punpckhbw(xmm4, xmm3);
- movdqu(xword[B-0x60], xmm1);
- movdqu(xword[B-0x50], xmm4);
- sub(B, -64);
- align(4);
-
-L(l388);
- test(M, 0x1);
- jle(l3b0, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- add(A1, LDA);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l3b0);
- sub(N, 0x20);
- cmp(N, 0x20);
- jge(l240, T_NEAR);
- align(4);
-
-L(l3c0);
- cmp(N, 0x10);
- jl(l4b8, T_NEAR);
- align(4);
-
-L(l3cc);
- mov(A1, A);
- add(A, 0x10);
- mov(I, M);
- sar(I, 0x2);
- jle(l454, T_NEAR);
- align(4);
-
-L(l3dc);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm1, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm2, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm3, xword[A1-0x80]);
- add(A1, LDA);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm1, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm1, xmm3);
- movdqa(xmm3, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm3, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm1);
- punpckhwd(xmm2, xmm1);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm3);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm2);
- sub(B, -64);
- dec(I);
- jg(l3dc, T_NEAR);
- align(4);
-
-L(l454);
- test(M, 0x2);
- jle(l48c, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm1, xword[A1-0x80]);
- add(A1, LDA);
- movdqa(xmm2, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm2, xmm1);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- align(4);
-
-L(l48c);
- test(M, 0x1);
- jle(l4a8, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l4a8);
- sub(N, 0x10);
- cmp(N, 0x10);
- jge(l3cc, T_NEAR);
- align(4);
-
-L(l4b8);
- cmp(N, 0x8);
- jl(l61c, T_NEAR);
- align(4);
-
-L(l4c4);
- mov(A1, A);
- add(A, 0x8);
- mov(I, M);
- sar(I, 0x3);
- jle(l570, T_NEAR);
- align(4);
-
-L(l4d8);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- dec(I);
- jg(l4d8, T_NEAR);
- align(4);
-
-L(l570);
- test(M, 0x4);
- jle(l5c4, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l5c4);
- test(M, 0x2);
- jle(l5f0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l5f0);
- test(M, 0x1);
- jle(l60c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l60c);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l4c4, T_NEAR);
- align(4);
-
-L(l61c);
- cmp(N, 0x4);
- jl(l74c, T_NEAR);
- align(4);
-
-L(l628);
- mov(A1, A);
- add(A, 0x4);
- mov(I, M);
- sar(I, 0x3);
- jle(l6b0, T_NEAR);
- align(4);
-
-L(l638);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- dec(I);
- jg(l638, T_NEAR);
- align(4);
-
-L(l6b0);
- test(M, 0x4);
- jle(l6f4, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l6f4);
- test(M, 0x2);
- jle(l720, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l720);
- test(M, 0x1);
- jle(l73c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l73c);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l628, T_NEAR);
- align(4);
-
-L(l74c);
- cmp(N, 0x2);
- jl(l8b2, T_NEAR);
- align(4);
-
-L(l758);
- mov(A1, A);
- add(A, 0x2);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l804, T_NEAR);
- align(4);
-
-L(l76c);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm4, eax, 0x0);
- punpcklbw(xmm1, xmm2);
- punpcklbw(xmm3, xmm4);
- punpcklwd(xmm1, xmm3);
- punpcklqdq(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(LDA3);
- jg(l76c, T_NEAR);
- align(4);
-
-L(l804);
- test(M, 0x4);
- jle(l858, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l858);
- test(M, 0x2);
- jle(l88c, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l88c);
- test(M, 0x1);
- jle(l8a4, T_NEAR);
- mov(ax, word[A1-0x80]);
- mov(word[B-0x80], ax);
- sub(B, -2);
- align(4);
-
-L(l8a4);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l758, T_NEAR);
- align(4);
-
-L(l8b2);
- cmp(N, 0x1);
- jl(l9d8, T_NEAR);
- align(4);
-
-L(l8bc);
- mov(A1, A);
- add(A, 0x1);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l944, T_NEAR);
- align(4);
-
-L(l8cc);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x7);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- dec(LDA3);
- jg(l8cc, T_NEAR);
- align(4);
-
-L(l944);
- test(M, 0x4);
- jle(l98c, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l98c);
- test(M, 0x2);
- jle(l9b0, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- mov(byte[B-0x80], al);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l9b0);
- test(M, 0x1);
- jle(l9c8, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l9c8);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l8bc, T_NEAR);
- align(4);
-
-L(l9d8);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
deleted file mode 100644
index 1c11fc6cef..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
+++ /dev/null
@@ -1,2209 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l1014;
-Xbyak::Label l1390;
-Xbyak::Label l159c;
-Xbyak::Label l173c;
-Xbyak::Label l18e4;
-Xbyak::Label l1a7c;
-Xbyak::Label l1a8c;
-Xbyak::Label l1a98;
-Xbyak::Label l1ab4;
-Xbyak::Label l1c64;
-Xbyak::Label l1d74;
-Xbyak::Label l1e50;
-Xbyak::Label l1f2c;
-Xbyak::Label l1ffc;
-Xbyak::Label l20;
-Xbyak::Label l200c;
-Xbyak::Label l2018;
-Xbyak::Label l2034;
-Xbyak::Label l2110;
-Xbyak::Label l21a0;
-Xbyak::Label l2210;
-Xbyak::Label l2284;
-Xbyak::Label l22f0;
-Xbyak::Label l2300;
-Xbyak::Label l230c;
-Xbyak::Label l2324;
-Xbyak::Label l2398;
-Xbyak::Label l23e8;
-Xbyak::Label l242c;
-Xbyak::Label l2474;
-Xbyak::Label l24b4;
-Xbyak::Label l24c4;
-Xbyak::Label l24d0;
-Xbyak::Label l24e8;
-Xbyak::Label l2520;
-Xbyak::Label l254c;
-Xbyak::Label l2578;
-Xbyak::Label l25a8;
-Xbyak::Label l25c8;
-Xbyak::Label l25d6;
-Xbyak::Label l25e0;
-Xbyak::Label l25f0;
-Xbyak::Label l260c;
-Xbyak::Label l262c;
-Xbyak::Label l264c;
-Xbyak::Label l2668;
-Xbyak::Label l2680;
-Xbyak::Label l2690;
-Xbyak::Label l44;
-Xbyak::Label l58c;
-Xbyak::Label l8b0;
-Xbyak::Label lb14;
-Xbyak::Label ld84;
-Xbyak::Label lfdc;
-Xbyak::Label lfec;
-Xbyak::Label lff8;
-
- preamble();
-#ifdef _WIN32
- auto stacksize = get_size_of_abi_save_regs();
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(N, qword[N]);
- mov(M, qword[M]);
- mov(LDA, qword[LDA]);
- sub(A, -128);
- sub(B, -128);
- lea(LDA3, ptr[LDA+LDA*2]);
- cmp(N, 0x30);
- jl(lfec, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x5);
- lea(I, ptr[I+LDA*8]);
- lea(I, ptr[I+LDA*8]);
- add(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l58c, T_NEAR);
- align(4);
-
-L(l44);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B+0x40], xmm1);
- movdqu(xword[B+0x100], xmm4);
- movdqu(xword[B+0x1c0], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B+0x50], xmm1);
- movdqu(xword[B+0x110], xmm4);
- movdqu(xword[B+0x1d0], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B+0x60], xmm1);
- movdqu(xword[B+0x120], xmm4);
- movdqu(xword[B+0x1e0], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B+0x70], xmm1);
- movdqu(xword[B+0x130], xmm4);
- movdqu(xword[B+0x1f0], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B+0x80], xmm1);
- movdqu(xword[B+0x140], xmm4);
- movdqu(xword[B+0x200], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x30], xmm0);
- movdqu(xword[B+0x90], xmm1);
- movdqu(xword[B+0x150], xmm4);
- movdqu(xword[B+0x210], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x20], xmm0);
- movdqu(xword[B+0xa0], xmm1);
- movdqu(xword[B+0x160], xmm4);
- movdqu(xword[B+0x220], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x10], xmm0);
- movdqu(xword[B+0xb0], xmm1);
- movdqu(xword[B+0x170], xmm4);
- movdqu(xword[B+0x230], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B], xmm0);
- movdqu(xword[B+0xc0], xmm1);
- movdqu(xword[B+0x180], xmm4);
- movdqu(xword[B+0x240], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B+0x10], xmm0);
- movdqu(xword[B+0xd0], xmm1);
- movdqu(xword[B+0x190], xmm4);
- movdqu(xword[B+0x250], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B+0x20], xmm0);
- movdqu(xword[B+0xe0], xmm1);
- movdqu(xword[B+0x1a0], xmm4);
- movdqu(xword[B+0x260], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B+0x30], xmm0);
- movdqu(xword[B+0xf0], xmm1);
- movdqu(xword[B+0x1b0], xmm4);
- movdqu(xword[B+0x270], xmm3);
- sub(A1, -16);
- sub(B, -768);
- dec(I);
- jg(l44, T_NEAR);
- align(4);
-
-L(l58c);
- test(M, 0x8);
- jle(l8b0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B+0x40], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B+0x50], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B+0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B+0x70], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B+0x80], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x30], xmm0);
- movdqu(xword[B+0x90], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x20], xmm0);
- movdqu(xword[B+0xa0], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x10], xmm0);
- movdqu(xword[B+0xb0], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B], xmm0);
- movdqu(xword[B+0xc0], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B+0x10], xmm0);
- movdqu(xword[B+0xd0], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B+0x20], xmm0);
- movdqu(xword[B+0xe0], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B+0x30], xmm0);
- movdqu(xword[B+0xf0], xmm1);
- sub(A1, -8);
- sub(B, -384);
- align(4);
-
-L(l8b0);
- test(M, 0x4);
- jle(lb14, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x50], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x40], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x30], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x20], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x10], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B+0x10], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B+0x20], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B+0x30], xmm0);
- sub(A1, -4);
- sub(B, -192);
- align(4);
-
-L(lb14);
- test(M, 0x2);
- jle(ld84, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x70], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x60], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x50], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x40], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x30], xmm0);
- sub(A1, -2);
- sub(B, -96);
- align(4);
-
-L(ld84);
- test(M, 0x1);
- jle(lfdc, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x80], xmm0);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x70], xmm0);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x60], xmm0);
- sub(B, -48);
- align(4);
-
-L(lfdc);
- sub(N, 0x30);
- cmp(N, 0x30);
- jge(l20, T_NEAR);
- align(4);
-
-L(lfec);
- cmp(N, 0x20);
- jl(l1a8c, T_NEAR);
- align(4);
-
-L(lff8);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x5);
- add(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l1390, T_NEAR);
- align(4);
-
-L(l1014);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B], xmm1);
- movdqu(xword[B+0x80], xmm4);
- movdqu(xword[B+0x100], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B+0x10], xmm1);
- movdqu(xword[B+0x90], xmm4);
- movdqu(xword[B+0x110], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B+0x20], xmm1);
- movdqu(xword[B+0xa0], xmm4);
- movdqu(xword[B+0x120], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B+0x30], xmm1);
- movdqu(xword[B+0xb0], xmm4);
- movdqu(xword[B+0x130], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B+0x40], xmm1);
- movdqu(xword[B+0xc0], xmm4);
- movdqu(xword[B+0x140], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x30], xmm0);
- movdqu(xword[B+0x50], xmm1);
- movdqu(xword[B+0xd0], xmm4);
- movdqu(xword[B+0x150], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x20], xmm0);
- movdqu(xword[B+0x60], xmm1);
- movdqu(xword[B+0xe0], xmm4);
- movdqu(xword[B+0x160], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x10], xmm0);
- movdqu(xword[B+0x70], xmm1);
- movdqu(xword[B+0xf0], xmm4);
- movdqu(xword[B+0x170], xmm3);
- sub(A1, -16);
- sub(B, -512);
- dec(I);
- jg(l1014, T_NEAR);
- align(4);
-
-L(l1390);
- test(M, 0x8);
- jle(l159c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B+0x10], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B+0x20], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B+0x30], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x40], xmm0);
- movdqu(xword[B+0x40], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x30], xmm0);
- movdqu(xword[B+0x50], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x20], xmm0);
- movdqu(xword[B+0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x10], xmm0);
- movdqu(xword[B+0x70], xmm1);
- sub(A1, -8);
- sub(B, -256);
- align(4);
-
-L(l159c);
- test(M, 0x4);
- jle(l173c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x50], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x40], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x30], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x20], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x10], xmm0);
- sub(A1, -4);
- sub(B, -128);
- align(4);
-
-L(l173c);
- test(M, 0x2);
- jle(l18e4, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x70], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x60], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- movdqu(xword[B-0x50], xmm0);
- sub(A1, -2);
- sub(B, -64);
- align(4);
-
-L(l18e4);
- test(M, 0x1);
- jle(l1a7c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x80], xmm0);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l1a7c);
- sub(N, 0x20);
- cmp(N, 0x20);
- jge(lff8, T_NEAR);
- align(4);
-
-L(l1a8c);
- cmp(N, 0x10);
- jl(l200c, T_NEAR);
- align(4);
-
-L(l1a98);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x4);
- add(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l1c64, T_NEAR);
- align(4);
-
-L(l1ab4);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x40], xmm1);
- movdqu(xword[B], xmm4);
- movdqu(xword[B+0x40], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x30], xmm1);
- movdqu(xword[B+0x10], xmm4);
- movdqu(xword[B+0x50], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x20], xmm1);
- movdqu(xword[B+0x20], xmm4);
- movdqu(xword[B+0x60], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B-0x10], xmm1);
- movdqu(xword[B+0x30], xmm4);
- movdqu(xword[B+0x70], xmm3);
- sub(A1, -16);
- sub(B, -256);
- dec(I);
- jg(l1ab4, T_NEAR);
- align(4);
-
-L(l1c64);
- test(M, 0x8);
- jle(l1d74, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x40], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x30], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x20], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x50], xmm0);
- movdqu(xword[B-0x10], xmm1);
- sub(A1, -8);
- sub(B, -128);
- align(4);
-
-L(l1d74);
- test(M, 0x4);
- jle(l1e50, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x50], xmm0);
- sub(A1, -4);
- sub(B, -64);
- align(4);
-
-L(l1e50);
- test(M, 0x2);
- jle(l1f2c, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x70], xmm0);
- sub(A1, -2);
- sub(B, -32);
- align(4);
-
-L(l1f2c);
- test(M, 0x1);
- jle(l1ffc, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0xf);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l1ffc);
- sub(N, 0x10);
- cmp(N, 0x10);
- jge(l1a98, T_NEAR);
- align(4);
-
-L(l200c);
- cmp(N, 0x8);
- jl(l2300, T_NEAR);
- align(4);
-
-L(l2018);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*4]);
- lea(I, ptr[A1+LDA*8]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l2110, T_NEAR);
- align(4);
-
-L(l2034);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- sub(A1, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x60], xmm1);
- movdqu(xword[B-0x40], xmm4);
- movdqu(xword[B-0x20], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x50], xmm1);
- movdqu(xword[B-0x30], xmm4);
- movdqu(xword[B-0x10], xmm3);
- sub(B, -128);
- dec(I);
- jg(l2034, T_NEAR);
- align(4);
-
-L(l2110);
- test(M, 0x8);
- jle(l21a0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- sub(A1, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- align(4);
-
-L(l21a0);
- test(M, 0x4);
- jle(l2210, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- sub(A1, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l2210);
- test(M, 0x2);
- jle(l2284, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l2284);
- test(M, 0x1);
- jle(l22f0, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x7);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l22f0);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l2018, T_NEAR);
- align(4);
-
-L(l2300);
- cmp(N, 0x4);
- jl(l24c4, T_NEAR);
- align(4);
-
-L(l230c);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*2]);
- lea(I, ptr[A1+LDA*4]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l2398, T_NEAR);
- align(4);
-
-L(l2324);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- sub(A1, -16);
- movdqu(xmm2, xword[A2-0x80]);
- movdqu(xmm3, xword[A2+LDA*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm3);
- sub(B, -64);
- dec(I);
- jg(l2324, T_NEAR);
- align(4);
-
-L(l2398);
- test(M, 0x8);
- jle(l23e8, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- sub(A1, -8);
- movq(xmm2, qword[A2-0x80]);
- movq(xmm3, qword[A2+LDA*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l23e8);
- test(M, 0x4);
- jle(l242c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- sub(A1, -4);
- movd(xmm2, dword[A2-0x80]);
- movd(xmm3, dword[A2+LDA*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l242c);
- test(M, 0x2);
- jle(l2474, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x3);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l2474);
- test(M, 0x1);
- jle(l24b4, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l24b4);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l230c, T_NEAR);
- align(4);
-
-L(l24c4);
- cmp(N, 0x2);
- jl(l25d6, T_NEAR);
- align(4);
-
-L(l24d0);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*1]);
- lea(I, ptr[A1+LDA*2]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l2520, T_NEAR);
- align(4);
-
-L(l24e8);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xmm1, xword[A2-0x80]);
- sub(A2, -16);
- movdqa(xmm2, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm2, xmm1);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- dec(I);
- jg(l24e8, T_NEAR);
- align(4);
-
-L(l2520);
- test(M, 0x8);
- jle(l254c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(xmm1, qword[A2-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l254c);
- test(M, 0x4);
- jle(l2578, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(xmm1, dword[A2-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l2578);
- test(M, 0x2);
- jle(l25a8, T_NEAR);
- mov(ax, word[A1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x1);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l25a8);
- test(M, 0x1);
- jle(l25c8, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- mov(al, byte[A2-0x80]);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l25c8);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l24d0, T_NEAR);
- align(4);
-
-L(l25d6);
- cmp(N, 0x1);
- jl(l2690, T_NEAR);
- align(4);
-
-L(l25e0);
- mov(A1, A);
- add(A, LDA);
- mov(I, M);
- sar(I, 0x4);
- jle(l260c, T_NEAR);
- align(4);
-
-L(l25f0);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(I);
- jg(l25f0, T_NEAR);
- align(4);
-
-L(l260c);
- test(M, 0x8);
- jle(l262c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l262c);
- test(M, 0x4);
- jle(l264c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l264c);
- test(M, 0x2);
- jle(l2668, T_NEAR);
- mov(ax, word[A1-0x80]);
- mov(word[B-0x80], ax);
- sub(A1, -2);
- sub(B, -2);
- align(4);
-
-L(l2668);
- test(M, 0x1);
- jle(l2680, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l2680);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l25e0, T_NEAR);
- align(4);
-
-L(l2690);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
deleted file mode 100644
index 56c36ee14a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
+++ /dev/null
@@ -1,564 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l118;
-Xbyak::Label l1a8;
-Xbyak::Label l20;
-Xbyak::Label l218;
-Xbyak::Label l28c;
-Xbyak::Label l2f8;
-Xbyak::Label l308;
-Xbyak::Label l314;
-Xbyak::Label l32c;
-Xbyak::Label l3a0;
-Xbyak::Label l3c;
-Xbyak::Label l3f0;
-Xbyak::Label l434;
-Xbyak::Label l47c;
-Xbyak::Label l4bc;
-Xbyak::Label l4cc;
-Xbyak::Label l4d8;
-Xbyak::Label l4f0;
-Xbyak::Label l528;
-Xbyak::Label l554;
-Xbyak::Label l580;
-Xbyak::Label l5b0;
-Xbyak::Label l5d0;
-Xbyak::Label l5de;
-Xbyak::Label l5e8;
-Xbyak::Label l5f8;
-Xbyak::Label l614;
-Xbyak::Label l634;
-Xbyak::Label l654;
-Xbyak::Label l670;
-Xbyak::Label l688;
-Xbyak::Label l698;
-
- preamble();
-#ifdef _WIN32
- auto stacksize = get_size_of_abi_save_regs();
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(N, qword[N]);
- mov(M, qword[M]);
- mov(LDA, qword[LDA]);
- sub(A, -128);
- sub(B, -128);
- lea(LDA3, ptr[LDA+LDA*2]);
- cmp(N, 0x8);
- jl(l308, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*4]);
- lea(I, ptr[A1+LDA*8]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l118, T_NEAR);
- align(4);
-
-L(l3c);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- sub(A1, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x60], xmm1);
- movdqu(xword[B-0x40], xmm4);
- movdqu(xword[B-0x20], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x50], xmm1);
- movdqu(xword[B-0x30], xmm4);
- movdqu(xword[B-0x10], xmm3);
- sub(B, -128);
- dec(I);
- jg(l3c, T_NEAR);
- align(4);
-
-L(l118);
- test(M, 0x8);
- jle(l1a8, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- sub(A1, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x70], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- align(4);
-
-L(l1a8);
- test(M, 0x4);
- jle(l218, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- sub(A1, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l218);
- test(M, 0x2);
- jle(l28c, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x7);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l28c);
- test(M, 0x1);
- jle(l2f8, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x7);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l2f8);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l20, T_NEAR);
- align(4);
-
-L(l308);
- cmp(N, 0x4);
- jl(l4cc, T_NEAR);
- align(4);
-
-L(l314);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*2]);
- lea(I, ptr[A1+LDA*4]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l3a0, T_NEAR);
- align(4);
-
-L(l32c);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- sub(A1, -16);
- movdqu(xmm2, xword[A2-0x80]);
- movdqu(xmm3, xword[A2+LDA*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm3);
- sub(B, -64);
- dec(I);
- jg(l32c, T_NEAR);
- align(4);
-
-L(l3a0);
- test(M, 0x8);
- jle(l3f0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- sub(A1, -8);
- movq(xmm2, qword[A2-0x80]);
- movq(xmm3, qword[A2+LDA*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l3f0);
- test(M, 0x4);
- jle(l434, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- sub(A1, -4);
- movd(xmm2, dword[A2-0x80]);
- movd(xmm3, dword[A2+LDA*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l434);
- test(M, 0x2);
- jle(l47c, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x3);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l47c);
- test(M, 0x1);
- jle(l4bc, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l4bc);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l314, T_NEAR);
- align(4);
-
-L(l4cc);
- cmp(N, 0x2);
- jl(l5de, T_NEAR);
- align(4);
-
-L(l4d8);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*1]);
- lea(I, ptr[A1+LDA*2]);
- mov(A, I);
- mov(I, M);
- sar(I, 0x4);
- jle(l528, T_NEAR);
- align(4);
-
-L(l4f0);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xmm1, xword[A2-0x80]);
- sub(A2, -16);
- movdqa(xmm2, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm2, xmm1);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- dec(I);
- jg(l4f0, T_NEAR);
- align(4);
-
-L(l528);
- test(M, 0x8);
- jle(l554, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(xmm1, qword[A2-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l554);
- test(M, 0x4);
- jle(l580, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(xmm1, dword[A2-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l580);
- test(M, 0x2);
- jle(l5b0, T_NEAR);
- mov(ax, word[A1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x1);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l5b0);
- test(M, 0x1);
- jle(l5d0, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- mov(al, byte[A2-0x80]);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l5d0);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l4d8, T_NEAR);
- align(4);
-
-L(l5de);
- cmp(N, 0x1);
- jl(l698, T_NEAR);
- align(4);
-
-L(l5e8);
- mov(A1, A);
- add(A, LDA);
- mov(I, M);
- sar(I, 0x4);
- jle(l614, T_NEAR);
- align(4);
-
-L(l5f8);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(I);
- jg(l5f8, T_NEAR);
- align(4);
-
-L(l614);
- test(M, 0x8);
- jle(l634, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l634);
- test(M, 0x4);
- jle(l654, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l654);
- test(M, 0x2);
- jle(l670, T_NEAR);
- mov(ax, word[A1-0x80]);
- mov(word[B-0x80], ax);
- sub(A1, -2);
- sub(B, -2);
- align(4);
-
-L(l670);
- test(M, 0x1);
- jle(l688, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l688);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l5e8, T_NEAR);
- align(4);
-
-L(l698);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
deleted file mode 100644
index 53e99d94de..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
+++ /dev/null
@@ -1,501 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l120;
-Xbyak::Label l14c;
-Xbyak::Label l168;
-Xbyak::Label l178;
-Xbyak::Label l184;
-Xbyak::Label l194;
-Xbyak::Label l20;
-Xbyak::Label l20c;
-Xbyak::Label l250;
-Xbyak::Label l27c;
-Xbyak::Label l298;
-Xbyak::Label l2a8;
-Xbyak::Label l2b4;
-Xbyak::Label l2c8;
-Xbyak::Label l34;
-Xbyak::Label l360;
-Xbyak::Label l3b4;
-Xbyak::Label l3e8;
-Xbyak::Label l400;
-Xbyak::Label l40e;
-Xbyak::Label l418;
-Xbyak::Label l428;
-Xbyak::Label l4a0;
-Xbyak::Label l4e8;
-Xbyak::Label l50c;
-Xbyak::Label l524;
-Xbyak::Label l534;
-Xbyak::Label lcc;
-
- preamble();
-#ifdef _WIN32
- auto stacksize = get_size_of_abi_save_regs();
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(M, qword[M]);
- mov(N, qword[N]);
- mov(LDA, qword[LDA]);
- lea(LDA3, ptr[LDA+LDA*2]);
- sub(A, -128);
- sub(B, -128);
- cmp(N, 0x8);
- jl(l178, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- add(A, 0x8);
- mov(I, M);
- sar(I, 0x3);
- jle(lcc, T_NEAR);
- align(4);
-
-L(l34);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- dec(I);
- jg(l34, T_NEAR);
- align(4);
-
-L(lcc);
- test(M, 0x4);
- jle(l120, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l120);
- test(M, 0x2);
- jle(l14c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l14c);
- test(M, 0x1);
- jle(l168, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l168);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l20, T_NEAR);
- align(4);
-
-L(l178);
- cmp(N, 0x4);
- jl(l2a8, T_NEAR);
- align(4);
-
-L(l184);
- mov(A1, A);
- add(A, 0x4);
- mov(I, M);
- sar(I, 0x3);
- jle(l20c, T_NEAR);
- align(4);
-
-L(l194);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- dec(I);
- jg(l194, T_NEAR);
- align(4);
-
-L(l20c);
- test(M, 0x4);
- jle(l250, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l250);
- test(M, 0x2);
- jle(l27c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l27c);
- test(M, 0x1);
- jle(l298, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l298);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l184, T_NEAR);
- align(4);
-
-L(l2a8);
- cmp(N, 0x2);
- jl(l40e, T_NEAR);
- align(4);
-
-L(l2b4);
- mov(A1, A);
- add(A, 0x2);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l360, T_NEAR);
- align(4);
-
-L(l2c8);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm4, eax, 0x0);
- punpcklbw(xmm1, xmm2);
- punpcklbw(xmm3, xmm4);
- punpcklwd(xmm1, xmm3);
- punpcklqdq(xmm0, xmm1);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(LDA3);
- jg(l2c8, T_NEAR);
- align(4);
-
-L(l360);
- test(M, 0x4);
- jle(l3b4, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l3b4);
- test(M, 0x2);
- jle(l3e8, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l3e8);
- test(M, 0x1);
- jle(l400, T_NEAR);
- mov(ax, word[A1-0x80]);
- mov(word[B-0x80], ax);
- sub(B, -2);
- align(4);
-
-L(l400);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l2b4, T_NEAR);
- align(4);
-
-L(l40e);
- cmp(N, 0x1);
- jl(l534, T_NEAR);
- align(4);
-
-L(l418);
- mov(A1, A);
- add(A, 0x1);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l4a0, T_NEAR);
- align(4);
-
-L(l428);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x7);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- dec(LDA3);
- jg(l428, T_NEAR);
- align(4);
-
-L(l4a0);
- test(M, 0x4);
- jle(l4e8, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l4e8);
- test(M, 0x2);
- jle(l50c, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- mov(byte[B-0x80], al);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l50c);
- test(M, 0x1);
- jle(l524, T_NEAR);
- mov(al, byte[A1-0x80]);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l524);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l418, T_NEAR);
- align(4);
-
-L(l534);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
deleted file mode 100644
index 49a312fc88..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
+++ /dev/null
@@ -1,1283 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#define ARG_BIAS 24+stacksize+rsp
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-#define ARG_BIAS 72+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l1024;
-Xbyak::Label l1090;
-Xbyak::Label l10d4;
-Xbyak::Label l10fc;
-Xbyak::Label l111a;
-Xbyak::Label l1124;
-Xbyak::Label l113c;
-Xbyak::Label l11d4;
-Xbyak::Label l1234;
-Xbyak::Label l1278;
-Xbyak::Label l129c;
-Xbyak::Label l12bc;
-Xbyak::Label l20;
-Xbyak::Label l2a0;
-Xbyak::Label l3c0;
-Xbyak::Label l438;
-Xbyak::Label l480;
-Xbyak::Label l48c;
-Xbyak::Label l4c8;
-Xbyak::Label l5c;
-Xbyak::Label l6a8;
-Xbyak::Label l7b4;
-Xbyak::Label l850;
-Xbyak::Label l89c;
-Xbyak::Label l8a8;
-Xbyak::Label l8d0;
-Xbyak::Label l9d0;
-Xbyak::Label la64;
-Xbyak::Label lab8;
-Xbyak::Label lae8;
-Xbyak::Label laf4;
-Xbyak::Label lb14;
-Xbyak::Label lc30;
-Xbyak::Label lcc8;
-Xbyak::Label ld1c;
-Xbyak::Label ld54;
-Xbyak::Label ld78;
-Xbyak::Label ld84;
-Xbyak::Label ld9c;
-Xbyak::Label le58;
-Xbyak::Label lebc;
-Xbyak::Label lef8;
-Xbyak::Label lf1c;
-Xbyak::Label lf3c;
-Xbyak::Label lf48;
-Xbyak::Label lf60;
-
- preamble();
- auto stacksize = get_size_of_abi_save_regs();
-#ifdef _WIN32
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(M, qword[M]);
- mov(N, qword[N]);
- mov(LDA, qword[LDA]);
- lea(LDA3, ptr[LDA+LDA*2]);
- sub(A, -128);
- sub(B, -128);
- cmp(N, 0x30);
- jl(l480, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- add(A, 0x30);
- vxorps(ymm8, ymm8, ymm8);
- vxorps(ymm9, ymm9, ymm9);
- vxorps(ymm10, ymm10, ymm10);
- vxorps(ymm11, ymm11, ymm11);
- vxorps(ymm12, ymm12, ymm12);
- vxorps(ymm13, ymm13, ymm13);
- vxorps(ymm14, ymm14, ymm14);
- vxorps(ymm15, ymm15, ymm15);
- mov(I, M);
- sar(I, 0x2);
- jle(l2a0, T_NEAR);
- align(4);
-
-L(l5c);
- vmovdqu(xmm0, xword[A1-0x80]);
- vmovdqu(xmm1, xword[A1+LDA*1-0x80]);
- vmovdqu(xmm2, xword[A1+LDA*2-0x80]);
- vmovdqu(xmm3, xword[A1+LDA3*1-0x80]);
- vpunpcklbw(xmm4, xmm0, xmm1);
- vpunpckhbw(xmm5, xmm0, xmm1);
- vpunpcklbw(xmm6, xmm2, xmm3);
- vpunpckhbw(xmm7, xmm2, xmm3);
- vpunpcklwd(xmm0, xmm4, xmm6);
- vpunpckhwd(xmm1, xmm4, xmm6);
- vpunpcklwd(xmm2, xmm5, xmm7);
- vpunpckhwd(xmm3, xmm5, xmm7);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm8, ymm8, ymm5);
- vmovdqu(xword[B-0x80], xmm0);
- vmovdqu(xword[B-0x70], xmm1);
- vpmovsxbw(ymm5, xmm2);
- vmovhlps(xmm6, xmm2, xmm2);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm9, ymm9, ymm5);
- vmovdqu(xword[B-0x60], xmm2);
- vmovdqu(xword[B-0x50], xmm3);
- vmovdqu(xmm0, xword[A1-0x70]);
- vmovdqu(xmm1, xword[A1+LDA*1-0x70]);
- vmovdqu(xmm2, xword[A1+LDA*2-0x70]);
- vmovdqu(xmm3, xword[A1+LDA3*1-0x70]);
- vpunpcklbw(xmm4, xmm0, xmm1);
- vpunpckhbw(xmm5, xmm0, xmm1);
- vpunpcklbw(xmm6, xmm2, xmm3);
- vpunpckhbw(xmm7, xmm2, xmm3);
- vpunpcklwd(xmm0, xmm4, xmm6);
- vpunpckhwd(xmm1, xmm4, xmm6);
- vpunpcklwd(xmm2, xmm5, xmm7);
- vpunpckhwd(xmm3, xmm5, xmm7);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm10, ymm10, ymm5);
- vmovdqu(xword[B-0x40], xmm0);
- vmovdqu(xword[B-0x30], xmm1);
- vpmovsxbw(ymm5, xmm2);
- vmovhlps(xmm6, xmm2, xmm2);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm11, ymm11, ymm5);
- vmovdqu(xword[B-0x20], xmm2);
- vmovdqu(xword[B-0x10], xmm3);
- vmovdqu(xmm0, xword[A1-0x60]);
- vmovdqu(xmm1, xword[A1+LDA*1-0x60]);
- vmovdqu(xmm2, xword[A1+LDA*2-0x60]);
- vmovdqu(xmm3, xword[A1+LDA3*1-0x60]);
- lea(A1, ptr[A1+LDA*4]);
- vpunpcklbw(xmm4, xmm0, xmm1);
- vpunpckhbw(xmm5, xmm0, xmm1);
- vpunpcklbw(xmm6, xmm2, xmm3);
- vpunpckhbw(xmm7, xmm2, xmm3);
- vpunpcklwd(xmm0, xmm4, xmm6);
- vpunpckhwd(xmm1, xmm4, xmm6);
- vpunpcklwd(xmm2, xmm5, xmm7);
- vpunpckhwd(xmm3, xmm5, xmm7);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm12, ymm12, ymm5);
- vmovdqu(xword[B], xmm0);
- vmovdqu(xword[B+0x10], xmm1);
- vpmovsxbw(ymm5, xmm2);
- vmovhlps(xmm6, xmm2, xmm2);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm13, ymm13, ymm5);
- vmovdqu(xword[B+0x20], xmm2);
- vmovdqu(xword[B+0x30], xmm3);
- sub(B, -192);
- dec(I);
- jg(l5c, T_NEAR);
- align(4);
-
-L(l2a0);
- test(M, 0x2);
- jle(l3c0, T_NEAR);
- vmovdqu(xmm0, xword[A1-0x80]);
- vmovdqu(xmm1, xword[A1-0x70]);
- vmovdqu(xmm2, xword[A1-0x60]);
- add(A1, LDA);
- vmovdqu(xmm6, xword[A1-0x80]);
- vmovdqu(xmm4, xword[A1-0x70]);
- vmovdqu(xmm5, xword[A1-0x60]);
- add(A1, LDA);
- vpunpcklbw(xmm3, xmm0, xmm6);
- vpunpckhbw(xmm0, xmm0, xmm6);
- vpmovsxbw(ymm7, xmm3);
- vmovhlps(xmm6, xmm3, xmm3);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm8, ymm8, ymm7);
- vmovdqu(xword[B-0x80], xmm3);
- vpmovsxbw(ymm7, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm9, ymm9, ymm7);
- vmovdqu(xword[B-0x70], xmm0);
- vpunpcklbw(xmm3, xmm1, xmm4);
- vpunpckhbw(xmm0, xmm1, xmm4);
- vpmovsxbw(ymm7, xmm3);
- vmovhlps(xmm6, xmm3, xmm3);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm10, ymm10, ymm7);
- vmovdqu(xword[B-0x60], xmm3);
- vpmovsxbw(ymm7, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm11, ymm11, ymm7);
- vmovdqu(xword[B-0x50], xmm0);
- vpunpcklbw(xmm3, xmm2, xmm5);
- vpunpckhbw(xmm0, xmm2, xmm5);
- vpmovsxbw(ymm7, xmm3);
- vmovhlps(xmm6, xmm3, xmm3);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm12, ymm12, ymm7);
- vmovdqu(xword[B-0x40], xmm3);
- vpmovsxbw(ymm7, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm7, ymm7, ymm6);
- vpmovsxwd(ymm7, xmm7);
- vpaddd(ymm13, ymm13, ymm7);
- vmovdqu(xword[B-0x30], xmm0);
- sub(B, -96);
- align(4);
-
-L(l3c0);
- test(M, 0x1);
- jle(l438, T_NEAR);
- vmovdqu(xmm0, xword[A1-0x80]);
- vmovdqu(xmm1, xword[A1-0x70]);
- vmovdqu(xmm2, xword[A1-0x60]);
- add(A1, LDA);
- vpmovsxbd(ymm7, xmm0);
- vpaddd(ymm8, ymm8, ymm7);
- vmovhlps(xmm7, xmm0, xmm0);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm9, ymm9, ymm7);
- vmovdqu(xword[B-0x80], xmm0);
- vpmovsxbd(ymm7, xmm1);
- vpaddd(ymm10, ymm10, ymm7);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm11, ymm11, ymm7);
- vmovdqu(xword[B-0x70], xmm1);
- vpmovsxbd(ymm7, xmm2);
- vpaddd(ymm12, ymm12, ymm7);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm13, ymm13, ymm7);
- vmovdqu(xword[B-0x60], xmm2);
- sub(B, -48);
- align(4);
-
-L(l438);
- mov(A1, qword[ARG_BIAS]);
- vmovdqu(yword[A1], ymm8);
- vmovdqu(yword[A1+0x20], ymm9);
- vmovdqu(yword[A1+0x40], ymm10);
- vmovdqu(yword[A1+0x60], ymm11);
- vmovdqu(yword[A1+0x80], ymm12);
- vmovdqu(yword[A1+0xa0], ymm13);
- add(qword[ARG_BIAS], 0xc0);
- sub(N, 0x30);
- cmp(N, 0x30);
- jge(l20, T_NEAR);
- vzeroupper();
- align(4);
-
-L(l480);
- cmp(N, 0x20);
- jl(l89c, T_NEAR);
- align(4);
-
-L(l48c);
- mov(A1, A);
- add(A, 0x20);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- pxor(xmm10, xmm10);
- pxor(xmm11, xmm11);
- pxor(xmm12, xmm12);
- pxor(xmm13, xmm13);
- pxor(xmm14, xmm14);
- pxor(xmm15, xmm15);
- mov(I, M);
- sar(I, 0x2);
- jle(l6a8, T_NEAR);
- align(4);
-
-L(l4c8);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm4);
- pmovsxbw(xmm5, xmm2);
- movhlps(xmm6, xmm2);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm2);
- movdqu(xmm0, xword[A1-0x70]);
- movdqu(xmm1, xword[A1+LDA*1-0x70]);
- movdqu(xmm2, xword[A1+LDA*2-0x70]);
- movdqu(xmm3, xword[A1+LDA3*1-0x70]);
- lea(A1, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm5);
- punpckhwd(xmm2, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B-0x40], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B-0x30], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B-0x20], xmm4);
- pmovsxbw(xmm5, xmm2);
- movhlps(xmm6, xmm2);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B-0x10], xmm2);
- sub(B, -128);
- dec(I);
- jg(l4c8, T_NEAR);
- align(4);
-
-L(l6a8);
- test(M, 0x2);
- jle(l7b4, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- add(A1, LDA);
- movdqu(xmm2, xword[A1-0x80]);
- movdqu(xmm3, xword[A1-0x70]);
- add(A1, LDA);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm2);
- punpckhbw(xmm4, xmm2);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm4);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x70], xmm4);
- movdqa(xmm4, xmm1);
- punpcklbw(xmm1, xmm3);
- punpckhbw(xmm4, xmm3);
- pmovsxbw(xmm5, xmm1);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm13, xmm6);
- movdqu(xword[B-0x60], xmm1);
- pmovsxbw(xmm5, xmm4);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm15, xmm6);
- movdqu(xword[B-0x50], xmm4);
- sub(B, -64);
- align(4);
-
-L(l7b4);
- test(M, 0x1);
- jle(l850, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1-0x70]);
- add(A1, LDA);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm8, xmm5);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- pshufd(xmm5, xmm0, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- pshufd(xmm6, xmm0, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbd(xmm5, xmm1);
- paddd(xmm12, xmm5);
- pshufd(xmm6, xmm1, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm13, xmm6);
- pshufd(xmm5, xmm1, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- pshufd(xmm6, xmm1, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm15, xmm6);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l850);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- movdqu(xword[A1+0x20], xmm10);
- movdqu(xword[A1+0x30], xmm11);
- movdqu(xword[A1+0x40], xmm12);
- movdqu(xword[A1+0x50], xmm13);
- movdqu(xword[A1+0x60], xmm14);
- movdqu(xword[A1+0x70], xmm15);
- add(qword[ARG_BIAS], 0x80);
- sub(N, 0x20);
- cmp(N, 0x20);
- jge(l48c, T_NEAR);
- align(4);
-
-L(l89c);
- cmp(N, 0x10);
- jl(lae8, T_NEAR);
- align(4);
-
-L(l8a8);
- mov(A1, A);
- add(A, 0x10);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- pxor(xmm10, xmm10);
- pxor(xmm11, xmm11);
- mov(I, M);
- sar(I, 0x2);
- jle(l9d0, T_NEAR);
- align(4);
-
-L(l8d0);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm1, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm2, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm3, xword[A1-0x80]);
- add(A1, LDA);
- movdqa(xmm4, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm4, xmm1);
- movdqa(xmm1, xmm2);
- punpcklbw(xmm2, xmm3);
- punpckhbw(xmm1, xmm3);
- movdqa(xmm3, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm3, xmm2);
- movdqa(xmm2, xmm4);
- punpcklwd(xmm4, xmm1);
- punpckhwd(xmm2, xmm1);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm3);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- pmovsxbw(xmm5, xmm2);
- movhlps(xmm6, xmm2);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x60], xmm4);
- movdqu(xword[B-0x50], xmm2);
- sub(B, -64);
- dec(I);
- jg(l8d0, T_NEAR);
- align(4);
-
-L(l9d0);
- test(M, 0x2);
- jle(la64, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- movdqu(xmm1, xword[A1-0x80]);
- add(A1, LDA);
- movdqa(xmm2, xmm0);
- punpcklbw(xmm0, xmm1);
- punpckhbw(xmm2, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- pmovsxbw(xmm5, xmm2);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movhlps(xmm6, xmm2);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- align(4);
-
-L(la64);
- test(M, 0x1);
- jle(lab8, T_NEAR);
- movdqu(xmm0, xword[A1-0x80]);
- add(A1, LDA);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm8, xmm5);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- pshufd(xmm5, xmm0, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- pshufd(xmm6, xmm0, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(lab8);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- movdqu(xword[A1+0x20], xmm10);
- movdqu(xword[A1+0x30], xmm11);
- add(qword[ARG_BIAS], 0x40);
- sub(N, 0x10);
- cmp(N, 0x10);
- jge(l8a8, T_NEAR);
- align(4);
-
-L(lae8);
- cmp(N, 0x8);
- jl(ld78, T_NEAR);
- align(4);
-
-L(laf4);
- mov(A1, A);
- add(A, 0x8);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- mov(I, M);
- sar(I, 0x3);
- jle(lc30, T_NEAR);
- align(4);
-
-L(lb14);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- dec(I);
- jg(lb14, T_NEAR);
- align(4);
-
-L(lc30);
- test(M, 0x4);
- jle(lcc8, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(lcc8);
- test(M, 0x2);
- jle(ld1c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(ld1c);
- test(M, 0x1);
- jle(ld54, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- pmovsxbd(xmm5, xmm0);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm8, xmm5);
- paddd(xmm9, xmm6);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(ld54);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- add(qword[ARG_BIAS], 0x20);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(laf4, T_NEAR);
- align(4);
-
-L(ld78);
- cmp(N, 0x4);
- jl(lf3c, T_NEAR);
- align(4);
-
-L(ld84);
- mov(A1, A);
- add(A, 0x4);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x3);
- jle(le58, T_NEAR);
- align(4);
-
-L(ld9c);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- dec(I);
- jg(ld9c, T_NEAR);
- align(4);
-
-L(le58);
- test(M, 0x4);
- jle(lebc, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(lebc);
- test(M, 0x2);
- jle(lef8, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(lef8);
- test(M, 0x1);
- jle(lf1c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(lf1c);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm7);
- add(qword[ARG_BIAS], 0x10);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(ld84, T_NEAR);
- align(4);
-
-L(lf3c);
- cmp(N, 0x2);
- jl(l111a, T_NEAR);
- align(4);
-
-L(lf48);
- mov(A1, A);
- add(A, 0x2);
- pxor(xmm7, xmm7);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l1024, T_NEAR);
- align(4);
-
-L(lf60);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm4, eax, 0x0);
- punpcklbw(xmm1, xmm2);
- punpcklbw(xmm3, xmm4);
- punpcklwd(xmm1, xmm3);
- punpcklqdq(xmm0, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(LDA3);
- jg(lf60, T_NEAR);
- align(4);
-
-L(l1024);
- test(M, 0x4);
- jle(l1090, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l1090);
- test(M, 0x2);
- jle(l10d4, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l10d4);
- test(M, 0x1);
- jle(l10fc, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(word[B-0x80], ax);
- sub(B, -2);
- align(4);
-
-L(l10fc);
- mov(A1, qword[ARG_BIAS]);
- movq(qword[A1], xmm7);
- add(qword[ARG_BIAS], 0x8);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(lf48, T_NEAR);
- align(4);
-
-L(l111a);
- cmp(N, 0x1);
- jl(l12bc, T_NEAR);
- align(4);
-
-L(l1124);
- mov(A1, A);
- add(A, 0x1);
- pxor(xmm7, xmm7);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l11d4, T_NEAR);
- align(4);
-
-L(l113c);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- dec(LDA3);
- jg(l113c, T_NEAR);
- align(4);
-
-L(l11d4);
- test(M, 0x4);
- jle(l1234, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l1234);
- test(M, 0x2);
- jle(l1278, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(byte[B-0x80], al);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l1278);
- test(M, 0x1);
- jle(l129c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l129c);
- mov(A1, qword[ARG_BIAS]);
- movd(dword[A1], xmm7);
- add(qword[ARG_BIAS], 0x4);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l1124, T_NEAR);
- align(4);
-
-L(l12bc);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-#undef ARG_BIAS
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
deleted file mode 100644
index a4f4ff09c6..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
+++ /dev/null
@@ -1,3163 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#define ARG_BIAS 24+stacksize+rsp
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-#define ARG_BIAS 72+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l1750;
-Xbyak::Label l1b6c;
-Xbyak::Label l1e14;
-Xbyak::Label l20;
-Xbyak::Label l2068;
-Xbyak::Label l226c;
-Xbyak::Label l22b8;
-Xbyak::Label l22c4;
-Xbyak::Label l22f4;
-Xbyak::Label l26b4;
-Xbyak::Label l28cc;
-Xbyak::Label l2a2c;
-Xbyak::Label l2b5c;
-Xbyak::Label l2c64;
-Xbyak::Label l2c94;
-Xbyak::Label l2ca0;
-Xbyak::Label l2cc8;
-Xbyak::Label l2eac;
-Xbyak::Label l2fc0;
-Xbyak::Label l3078;
-Xbyak::Label l3118;
-Xbyak::Label l319c;
-Xbyak::Label l31c0;
-Xbyak::Label l31cc;
-Xbyak::Label l31ec;
-Xbyak::Label l32e4;
-Xbyak::Label l3378;
-Xbyak::Label l33dc;
-Xbyak::Label l3434;
-Xbyak::Label l347c;
-Xbyak::Label l349c;
-Xbyak::Label l34a8;
-Xbyak::Label l34c8;
-Xbyak::Label l3558;
-Xbyak::Label l35b0;
-Xbyak::Label l35f4;
-Xbyak::Label l3638;
-Xbyak::Label l366c;
-Xbyak::Label l368a;
-Xbyak::Label l3694;
-Xbyak::Label l36a8;
-Xbyak::Label l36ec;
-Xbyak::Label l3728;
-Xbyak::Label l3760;
-Xbyak::Label l3794;
-Xbyak::Label l37b8;
-Xbyak::Label l37d8;
-Xbyak::Label l5cc;
-Xbyak::Label l6c;
-Xbyak::Label l968;
-Xbyak::Label lc80;
-Xbyak::Label lf1c;
-Xbyak::Label lf64;
-Xbyak::Label lf70;
-Xbyak::Label lfb4;
-
- preamble();
- auto stacksize = get_size_of_abi_save_regs();
-#ifdef _WIN32
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(N, qword[N]);
- mov(M, qword[M]);
- mov(LDA, qword[LDA]);
- sub(A, -128);
- sub(B, -128);
- lea(LDA3, ptr[LDA+LDA*2]);
- cmp(N, 0x30);
- jl(lf64, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x5);
- lea(I, ptr[I+LDA*8]);
- lea(I, ptr[I+LDA*8]);
- add(A, I);
- vxorps(ymm8, ymm8, ymm8);
- vxorps(ymm9, ymm9, ymm9);
- vxorps(ymm10, ymm10, ymm10);
- vxorps(ymm11, ymm11, ymm11);
- vxorps(ymm12, ymm12, ymm12);
- vxorps(ymm13, ymm13, ymm13);
- vxorps(ymm14, ymm14, ymm14);
- vxorps(ymm15, ymm15, ymm15);
- mov(I, M);
- sar(I, 0x3);
- jle(l5cc, T_NEAR);
- align(4);
-
-L(l6c);
- vmovq(xmm0, qword[A1-0x80]);
- vmovq(xmm1, qword[A1+LDA*1-0x80]);
- vmovq(xmm2, qword[A1+LDA*2-0x80]);
- vmovq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x80], xmm0);
- vmovdqu(xword[B+0x40], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B-0x70], xmm2);
- vmovdqu(xword[B+0x50], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm8, ymm8, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm8, ymm8, ymm5);
- vmovq(xmm0, qword[A2-0x80]);
- vmovq(xmm1, qword[A2+LDA*1-0x80]);
- vmovq(xmm2, qword[A2+LDA*2-0x80]);
- vmovq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x60], xmm0);
- vmovdqu(xword[B+0x60], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B-0x50], xmm2);
- vmovdqu(xword[B+0x70], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm9, ymm9, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm9, ymm9, ymm5);
- vmovq(xmm0, qword[A2-0x80]);
- vmovq(xmm1, qword[A2+LDA*1-0x80]);
- vmovq(xmm2, qword[A2+LDA*2-0x80]);
- vmovq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x40], xmm0);
- vmovdqu(xword[B+0x80], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B-0x30], xmm2);
- vmovdqu(xword[B+0x90], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm10, ymm10, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm10, ymm10, ymm5);
- vmovq(xmm0, qword[A2-0x80]);
- vmovq(xmm1, qword[A2+LDA*1-0x80]);
- vmovq(xmm2, qword[A2+LDA*2-0x80]);
- vmovq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x20], xmm0);
- vmovdqu(xword[B+0xa0], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B-0x10], xmm2);
- vmovdqu(xword[B+0xb0], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm11, ymm11, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm11, ymm11, ymm5);
- vmovq(xmm0, qword[A2-0x80]);
- vmovq(xmm1, qword[A2+LDA*1-0x80]);
- vmovq(xmm2, qword[A2+LDA*2-0x80]);
- vmovq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B], xmm0);
- vmovdqu(xword[B+0xc0], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B+0x10], xmm2);
- vmovdqu(xword[B+0xd0], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm12, ymm12, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm12, ymm12, ymm5);
- vmovq(xmm0, qword[A2-0x80]);
- vmovq(xmm1, qword[A2+LDA*1-0x80]);
- vmovq(xmm2, qword[A2+LDA*2-0x80]);
- vmovq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm0, xmm1);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm1, xmm3);
- vpunpckhqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B+0x20], xmm0);
- vmovdqu(xword[B+0xe0], xmm1);
- vmovq(xmm2, qword[A2-0x80]);
- vmovq(xmm3, qword[A2+LDA*1-0x80]);
- vmovq(xmm4, qword[A2+LDA*2-0x80]);
- vmovq(xmm5, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm3, xmm2, xmm3);
- vpunpckldq(xmm5, xmm4, xmm5);
- vpunpcklqdq(xmm2, xmm3, xmm5);
- vpunpckhqdq(xmm3, xmm3, xmm5);
- vmovdqu(xword[B+0x30], xmm2);
- vmovdqu(xword[B+0xf0], xmm3);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm2);
- vmovhlps(xmm7, xmm2, xmm2);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm13, ymm13, ymm5);
- vpmovsxbw(ymm5, xmm1);
- vmovhlps(xmm6, xmm1, xmm1);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm3);
- vmovhlps(xmm7, xmm3, xmm3);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm13, ymm13, ymm5);
- sub(A1, -8);
- sub(B, -384);
- dec(I);
- jg(l6c, T_NEAR);
- align(4);
-
-L(l5cc);
- test(M, 0x4);
- jle(l968, T_NEAR);
- vmovd(xmm0, dword[A1-0x80]);
- vmovd(xmm1, dword[A1+LDA*1-0x80]);
- vmovd(xmm2, dword[A1+LDA*2-0x80]);
- vmovd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B-0x80], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x70], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm8, ymm8, ymm5);
- vmovd(xmm0, dword[A2-0x80]);
- vmovd(xmm1, dword[A2+LDA*1-0x80]);
- vmovd(xmm2, dword[A2+LDA*2-0x80]);
- vmovd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B-0x60], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x50], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm9, ymm9, ymm5);
- vmovd(xmm0, dword[A2-0x80]);
- vmovd(xmm1, dword[A2+LDA*1-0x80]);
- vmovd(xmm2, dword[A2+LDA*2-0x80]);
- vmovd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B-0x40], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x30], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm10, ymm10, ymm5);
- vmovd(xmm0, dword[A2-0x80]);
- vmovd(xmm1, dword[A2+LDA*1-0x80]);
- vmovd(xmm2, dword[A2+LDA*2-0x80]);
- vmovd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B-0x20], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B-0x10], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm11, ymm11, ymm5);
- vmovd(xmm0, dword[A2-0x80]);
- vmovd(xmm1, dword[A2+LDA*1-0x80]);
- vmovd(xmm2, dword[A2+LDA*2-0x80]);
- vmovd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B+0x10], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm12, ymm12, ymm5);
- vmovd(xmm0, dword[A2-0x80]);
- vmovd(xmm1, dword[A2+LDA*1-0x80]);
- vmovd(xmm2, dword[A2+LDA*2-0x80]);
- vmovd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm0, xmm0, xmm1);
- vpunpckldq(xmm2, xmm2, xmm3);
- vpunpcklqdq(xmm0, xmm0, xmm2);
- vmovdqu(xword[B+0x20], xmm0);
- vmovd(xmm1, dword[A2-0x80]);
- vmovd(xmm2, dword[A2+LDA*1-0x80]);
- vmovd(xmm3, dword[A2+LDA*2-0x80]);
- vmovd(xmm4, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpunpckldq(xmm1, xmm1, xmm2);
- vpunpckldq(xmm3, xmm3, xmm4);
- vpunpcklqdq(xmm1, xmm1, xmm3);
- vmovdqu(xword[B+0x30], xmm1);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxbw(ymm6, xmm1);
- vmovhlps(xmm7, xmm1, xmm1);
- vpmovsxbw(ymm7, xmm7);
- vphaddw(ymm6, ymm6, ymm7);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm13, ymm13, ymm5);
- sub(A1, -4);
- sub(B, -192);
- align(4);
-
-L(l968);
- test(M, 0x2);
- jle(lc80, T_NEAR);
- mov(ax, word[A1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm8, ymm8, ymm5);
- vmovdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm9, ymm9, ymm5);
- vmovdqu(xword[B-0x70], xmm0);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm10, ymm10, ymm5);
- vmovdqu(xword[B-0x60], xmm0);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm11, ymm11, ymm5);
- vmovdqu(xword[B-0x50], xmm0);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm12, ymm12, ymm5);
- vmovdqu(xword[B-0x40], xmm0);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrw(xmm0, xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- vpinsrw(xmm0, xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- vpmovsxbw(ymm5, xmm0);
- vmovhlps(xmm6, xmm0, xmm0);
- vpmovsxbw(ymm6, xmm6);
- vphaddw(ymm5, ymm5, ymm6);
- vpmovsxwd(ymm5, xmm5);
- vpaddd(ymm13, ymm13, ymm5);
- vmovdqu(xword[B-0x30], xmm0);
- sub(A1, -2);
- sub(B, -96);
- align(4);
-
-L(lc80);
- test(M, 0x1);
- jle(lf1c, T_NEAR);
- mov(al, byte[A1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xf);
- vpmovsxbd(ymm7, xmm0);
- vpaddd(ymm8, ymm8, ymm7);
- vmovhlps(xmm7, xmm0, xmm0);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm9, ymm9, ymm7);
- vmovdqu(xword[B-0x80], xmm0);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xf);
- vpmovsxbd(ymm7, xmm0);
- vpaddd(ymm10, ymm10, ymm7);
- vmovhlps(xmm7, xmm0, xmm0);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm11, ymm11, ymm7);
- vmovdqu(xword[B-0x70], xmm0);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- vpinsrb(xmm0, xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- vpinsrb(xmm0, xmm0, eax, 0xf);
- vpmovsxbd(ymm7, xmm0);
- vpaddd(ymm12, ymm12, ymm7);
- vmovhlps(xmm7, xmm0, xmm0);
- vpmovsxbd(ymm7, xmm7);
- vpaddd(ymm13, ymm13, ymm7);
- vmovdqu(xword[B-0x60], xmm0);
- sub(B, -48);
- align(4);
-
-L(lf1c);
- mov(A1, qword[ARG_BIAS]);
- vmovdqu(yword[A1], ymm8);
- vmovdqu(yword[A1+0x20], ymm9);
- vmovdqu(yword[A1+0x40], ymm10);
- vmovdqu(yword[A1+0x60], ymm11);
- vmovdqu(yword[A1+0x80], ymm12);
- vmovdqu(yword[A1+0xa0], ymm13);
- add(qword[ARG_BIAS], 0xc0);
- sub(N, 0x30);
- cmp(N, 0x30);
- jge(l20, T_NEAR);
- vzeroupper();
- align(4);
-
-L(lf64);
- cmp(N, 0x20);
- jl(l22b8, T_NEAR);
- align(4);
-
-L(lf70);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x5);
- add(A, I);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- pxor(xmm10, xmm10);
- pxor(xmm11, xmm11);
- pxor(xmm12, xmm12);
- pxor(xmm13, xmm13);
- pxor(xmm14, xmm14);
- pxor(xmm15, xmm15);
- mov(I, M);
- sar(I, 0x4);
- jle(l1750, T_NEAR);
- align(4);
-
-L(lfb4);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B+0x80], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B+0x100], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x10], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x90], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x110], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0x20], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0xa0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0x120], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0x30], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0xb0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0x130], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B-0x40], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B+0x40], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B+0xc0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B+0x140], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B-0x30], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B+0x50], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B+0xd0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B+0x150], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B-0x20], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B+0x60], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B+0xe0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B+0x160], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B-0x10], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B+0x70], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B+0xf0], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B+0x170], xmm3);
- sub(A1, -16);
- sub(B, -512);
- dec(I);
- jg(lfb4, T_NEAR);
- align(4);
-
-L(l1750);
- test(M, 0x8);
- jle(l1b6c, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x10], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0x20], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0x30], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B-0x40], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B+0x40], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B-0x30], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B+0x50], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B-0x20], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B+0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B-0x10], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B+0x70], xmm1);
- sub(A1, -8);
- sub(B, -256);
- align(4);
-
-L(l1b6c);
- test(M, 0x4);
- jle(l1e14, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movdqu(xword[B-0x40], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm13, xmm5);
- movdqu(xword[B-0x30], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movdqu(xword[B-0x20], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm15, xmm5);
- movdqu(xword[B-0x10], xmm0);
- sub(A1, -4);
- sub(B, -128);
- align(4);
-
-L(l1e14);
- test(M, 0x2);
- jle(l2068, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x70], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm12, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm13, xmm6);
- movdqu(xword[B-0x60], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- lea(A2, ptr[A2+LDA*4]);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm15, xmm6);
- movdqu(xword[B-0x50], xmm0);
- sub(A1, -2);
- sub(B, -64);
- align(4);
-
-L(l2068);
- test(M, 0x1);
- jle(l226c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm8, xmm5);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- pshufd(xmm5, xmm0, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- pshufd(xmm6, xmm0, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x80], xmm0);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xf);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm12, xmm5);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm13, xmm6);
- pshufd(xmm5, xmm0, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm14, xmm5);
- pshufd(xmm6, xmm0, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm15, xmm6);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l226c);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- movdqu(xword[A1+0x20], xmm10);
- movdqu(xword[A1+0x30], xmm11);
- movdqu(xword[A1+0x40], xmm12);
- movdqu(xword[A1+0x50], xmm13);
- movdqu(xword[A1+0x60], xmm14);
- movdqu(xword[A1+0x70], xmm15);
- add(qword[ARG_BIAS], 0x80);
- sub(N, 0x20);
- cmp(N, 0x20);
- jge(lf70, T_NEAR);
- align(4);
-
-L(l22b8);
- cmp(N, 0x10);
- jl(l2c94, T_NEAR);
- align(4);
-
-L(l22c4);
- mov(A1, A);
- mov(I, LDA);
- shl(I, 0x4);
- add(A, I);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- pxor(xmm10, xmm10);
- pxor(xmm11, xmm11);
- mov(I, M);
- sar(I, 0x4);
- jle(l26b4, T_NEAR);
- align(4);
-
-L(l22f4);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x40], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B+0x40], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x30], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x10], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B+0x50], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x20], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0x20], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B+0x60], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x10], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0x30], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B+0x70], xmm3);
- sub(A1, -16);
- sub(B, -256);
- dec(I);
- jg(l22f4, T_NEAR);
- align(4);
-
-L(l26b4);
- test(M, 0x8);
- jle(l28cc, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x40], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x30], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x20], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x10], xmm1);
- sub(A1, -8);
- sub(B, -128);
- align(4);
-
-L(l28cc);
- test(M, 0x4);
- jle(l2a2c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm11, xmm5);
- movdqu(xword[B-0x50], xmm0);
- sub(A1, -4);
- sub(B, -64);
- align(4);
-
-L(l2a2c);
- test(M, 0x2);
- jle(l2b5c, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- pinsrw(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x70], xmm0);
- sub(A1, -2);
- sub(B, -32);
- align(4);
-
-L(l2b5c);
- test(M, 0x1);
- jle(l2c64, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- lea(A2, ptr[A1+LDA*4]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0x7);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x8);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x9);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xa);
- mov(al, byte[A2+LDA3*1-0x80]);
- lea(A2, ptr[A2+LDA*4]);
- pinsrb(xmm0, eax, 0xb);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0xc);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0xd);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0xe);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0xf);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm8, xmm5);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- pshufd(xmm5, xmm0, 0xaa);
- pmovsxbd(xmm5, xmm5);
- paddd(xmm10, xmm5);
- pshufd(xmm6, xmm0, 0xff);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm11, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l2c64);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- movdqu(xword[A1+0x20], xmm10);
- movdqu(xword[A1+0x30], xmm11);
- add(qword[ARG_BIAS], 0x40);
- sub(N, 0x10);
- cmp(N, 0x10);
- jge(l22c4, T_NEAR);
- align(4);
-
-L(l2c94);
- cmp(N, 0x8);
- jl(l31c0, T_NEAR);
- align(4);
-
-L(l2ca0);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*4]);
- lea(I, ptr[A1+LDA*8]);
- mov(A, I);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- mov(I, M);
- sar(I, 0x4);
- jle(l2eac, T_NEAR);
- align(4);
-
-L(l2cc8);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- sub(A1, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x60], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x40], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x20], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x50], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x30], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x10], xmm3);
- sub(B, -128);
- dec(I);
- jg(l2cc8, T_NEAR);
- align(4);
-
-L(l2eac);
- test(M, 0x8);
- jle(l2fc0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- sub(A1, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- align(4);
-
-L(l2fc0);
- test(M, 0x4);
- jle(l3078, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- sub(A1, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l3078);
- test(M, 0x2);
- jle(l3118, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l3118);
- test(M, 0x1);
- jle(l319c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x7);
- pmovsxbd(xmm5, xmm0);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm8, xmm5);
- paddd(xmm9, xmm6);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l319c);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- add(qword[ARG_BIAS], 0x20);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l2ca0, T_NEAR);
- align(4);
-
-L(l31c0);
- cmp(N, 0x4);
- jl(l349c, T_NEAR);
- align(4);
-
-L(l31cc);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*2]);
- lea(I, ptr[A1+LDA*4]);
- mov(A, I);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(l32e4, T_NEAR);
- align(4);
-
-L(l31ec);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- sub(A1, -16);
- movdqu(xmm2, xword[A2-0x80]);
- movdqu(xmm3, xword[A2+LDA*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x60], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x50], xmm3);
- sub(B, -64);
- dec(I);
- jg(l31ec, T_NEAR);
- align(4);
-
-L(l32e4);
- test(M, 0x8);
- jle(l3378, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- sub(A1, -8);
- movq(xmm2, qword[A2-0x80]);
- movq(xmm3, qword[A2+LDA*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l3378);
- test(M, 0x4);
- jle(l33dc, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- sub(A1, -4);
- movd(xmm2, dword[A2-0x80]);
- movd(xmm3, dword[A2+LDA*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l33dc);
- test(M, 0x2);
- jle(l3434, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x3);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l3434);
- test(M, 0x1);
- jle(l347c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l347c);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm7);
- add(qword[ARG_BIAS], 0x10);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l31cc, T_NEAR);
- align(4);
-
-L(l349c);
- cmp(N, 0x2);
- jl(l368a, T_NEAR);
- align(4);
-
-L(l34a8);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*1]);
- lea(I, ptr[A1+LDA*2]);
- mov(A, I);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(l3558, T_NEAR);
- align(4);
-
-L(l34c8);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xmm1, xword[A2-0x80]);
- sub(A2, -16);
- movdqa(xmm2, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm2, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pshufd(xmm6, xmm2, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- dec(I);
- jg(l34c8, T_NEAR);
- align(4);
-
-L(l3558);
- test(M, 0x8);
- jle(l35b0, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(xmm1, qword[A2-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l35b0);
- test(M, 0x4);
- jle(l35f4, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(xmm1, dword[A2-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l35f4);
- test(M, 0x2);
- jle(l3638, T_NEAR);
- mov(ax, word[A1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l3638);
- test(M, 0x1);
- jle(l366c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(byte[B-0x80], al);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- align(4);
-
-L(l366c);
- mov(A1, qword[ARG_BIAS]);
- movq(qword[A1], xmm7);
- add(qword[ARG_BIAS], 0x8);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l34a8, T_NEAR);
- align(4);
-
-L(l368a);
- cmp(N, 0x1);
- jl(l37d8, T_NEAR);
- align(4);
-
-L(l3694);
- mov(A1, A);
- add(A, LDA);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(l36ec, T_NEAR);
- align(4);
-
-L(l36a8);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(I);
- jg(l36a8, T_NEAR);
- align(4);
-
-L(l36ec);
- test(M, 0x8);
- jle(l3728, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l3728);
- test(M, 0x4);
- jle(l3760, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l3760);
- test(M, 0x2);
- jle(l3794, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- mov(word[B-0x80], ax);
- sub(A1, -2);
- sub(B, -2);
- align(4);
-
-L(l3794);
- test(M, 0x1);
- jle(l37b8, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l37b8);
- mov(A1, qword[ARG_BIAS]);
- movd(dword[A1], xmm7);
- add(qword[ARG_BIAS], 0x4);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l3694, T_NEAR);
- align(4);
-
-L(l37d8);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-#undef ARG_BIAS
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
deleted file mode 100644
index c7f1393c9d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
+++ /dev/null
@@ -1,821 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#define ARG_BIAS 24+stacksize+rsp
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-#define ARG_BIAS 72+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l20;
-Xbyak::Label l22c;
-Xbyak::Label l340;
-Xbyak::Label l3f8;
-Xbyak::Label l48;
-Xbyak::Label l498;
-Xbyak::Label l51c;
-Xbyak::Label l540;
-Xbyak::Label l54c;
-Xbyak::Label l56c;
-Xbyak::Label l664;
-Xbyak::Label l6f8;
-Xbyak::Label l75c;
-Xbyak::Label l7b4;
-Xbyak::Label l7fc;
-Xbyak::Label l81c;
-Xbyak::Label l828;
-Xbyak::Label l848;
-Xbyak::Label l8d8;
-Xbyak::Label l930;
-Xbyak::Label l974;
-Xbyak::Label l9b8;
-Xbyak::Label l9ec;
-Xbyak::Label la0a;
-Xbyak::Label la14;
-Xbyak::Label la28;
-Xbyak::Label la6c;
-Xbyak::Label laa8;
-Xbyak::Label lae0;
-Xbyak::Label lb14;
-Xbyak::Label lb38;
-Xbyak::Label lb58;
-
- preamble();
- auto stacksize = get_size_of_abi_save_regs();
-#ifdef _WIN32
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(N, qword[N]);
- mov(M, qword[M]);
- mov(LDA, qword[LDA]);
- sub(A, -128);
- sub(B, -128);
- lea(LDA3, ptr[LDA+LDA*2]);
- cmp(N, 0x8);
- jl(l540, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*4]);
- lea(I, ptr[A1+LDA*8]);
- mov(A, I);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- mov(I, M);
- sar(I, 0x4);
- jle(l22c, T_NEAR);
- align(4);
-
-L(l48);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- movdqu(xmm2, xword[A1+LDA*2-0x80]);
- movdqu(xmm3, xword[A1+LDA3*1-0x80]);
- sub(A1, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x60], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x40], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x20], xmm3);
- movdqu(xmm0, xword[A2-0x80]);
- movdqu(xmm1, xword[A2+LDA*1-0x80]);
- movdqu(xmm2, xword[A2+LDA*2-0x80]);
- movdqu(xmm3, xword[A2+LDA3*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x50], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x30], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x10], xmm3);
- sub(B, -128);
- dec(I);
- jg(l48, T_NEAR);
- align(4);
-
-L(l22c);
- test(M, 0x8);
- jle(l340, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- movq(xmm2, qword[A1+LDA*2-0x80]);
- movq(xmm3, qword[A1+LDA3*1-0x80]);
- sub(A1, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x60], xmm1);
- movq(xmm0, qword[A2-0x80]);
- movq(xmm1, qword[A2+LDA*1-0x80]);
- movq(xmm2, qword[A2+LDA*2-0x80]);
- movq(xmm3, qword[A2+LDA3*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- align(4);
-
-L(l340);
- test(M, 0x4);
- jle(l3f8, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- movd(xmm2, dword[A1+LDA*2-0x80]);
- movd(xmm3, dword[A1+LDA3*1-0x80]);
- sub(A1, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A2-0x80]);
- movd(xmm1, dword[A2+LDA*1-0x80]);
- movd(xmm2, dword[A2+LDA*2-0x80]);
- movd(xmm3, dword[A2+LDA3*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- align(4);
-
-L(l3f8);
- test(M, 0x2);
- jle(l498, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A1+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A1+LDA3*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x3);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x4);
- mov(ax, word[A2+LDA*1-0x80]);
- pinsrw(xmm0, eax, 0x5);
- mov(ax, word[A2+LDA*2-0x80]);
- pinsrw(xmm0, eax, 0x6);
- mov(ax, word[A2+LDA3*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l498);
- test(M, 0x1);
- jle(l51c, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A2+LDA*2-0x80]);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A2+LDA3*1-0x80]);
- pinsrb(xmm0, eax, 0x7);
- pmovsxbd(xmm5, xmm0);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm8, xmm5);
- paddd(xmm9, xmm6);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l51c);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- add(qword[ARG_BIAS], 0x20);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l20, T_NEAR);
- align(4);
-
-L(l540);
- cmp(N, 0x4);
- jl(l81c, T_NEAR);
- align(4);
-
-L(l54c);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*2]);
- lea(I, ptr[A1+LDA*4]);
- mov(A, I);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(l664, T_NEAR);
- align(4);
-
-L(l56c);
- movdqu(xmm0, xword[A1-0x80]);
- movdqu(xmm1, xword[A1+LDA*1-0x80]);
- sub(A1, -16);
- movdqu(xmm2, xword[A2-0x80]);
- movdqu(xmm3, xword[A2+LDA*1-0x80]);
- sub(A2, -16);
- movdqa(xmm4, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm4, xmm1);
- movdqa(xmm5, xmm2);
- punpckldq(xmm2, xmm3);
- punpckhdq(xmm5, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- movdqa(xmm3, xmm4);
- punpcklqdq(xmm4, xmm5);
- punpckhqdq(xmm3, xmm5);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm1);
- pmovsxbw(xmm5, xmm4);
- movhlps(xmm6, xmm4);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x60], xmm4);
- pmovsxbw(xmm5, xmm3);
- movhlps(xmm6, xmm3);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x50], xmm3);
- sub(B, -64);
- dec(I);
- jg(l56c, T_NEAR);
- align(4);
-
-L(l664);
- test(M, 0x8);
- jle(l6f8, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- movq(xmm1, qword[A1+LDA*1-0x80]);
- sub(A1, -8);
- movq(xmm2, qword[A2-0x80]);
- movq(xmm3, qword[A2+LDA*1-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklqdq(xmm0, xmm2);
- punpckhqdq(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l6f8);
- test(M, 0x4);
- jle(l75c, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- movd(xmm1, dword[A1+LDA*1-0x80]);
- sub(A1, -4);
- movd(xmm2, dword[A2-0x80]);
- movd(xmm3, dword[A2+LDA*1-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- punpckldq(xmm2, xmm3);
- punpcklqdq(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l75c);
- test(M, 0x2);
- jle(l7b4, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1+LDA*1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x1);
- mov(ax, word[A2-0x80]);
- pinsrw(xmm0, eax, 0x2);
- mov(ax, word[A2+LDA*1-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x3);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l7b4);
- test(M, 0x1);
- jle(l7fc, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A2+LDA*1-0x80]);
- pinsrb(xmm0, eax, 0x3);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l7fc);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm7);
- add(qword[ARG_BIAS], 0x10);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l54c, T_NEAR);
- align(4);
-
-L(l81c);
- cmp(N, 0x2);
- jl(la0a, T_NEAR);
- align(4);
-
-L(l828);
- mov(A1, A);
- lea(A2, ptr[A1+LDA*1]);
- lea(I, ptr[A1+LDA*2]);
- mov(A, I);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(l8d8, T_NEAR);
- align(4);
-
-L(l848);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- movdqu(xmm1, xword[A2-0x80]);
- sub(A2, -16);
- movdqa(xmm2, xmm0);
- punpckldq(xmm0, xmm1);
- punpckhdq(xmm2, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- pshufd(xmm6, xmm2, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm2);
- sub(B, -32);
- dec(I);
- jg(l848, T_NEAR);
- align(4);
-
-L(l8d8);
- test(M, 0x8);
- jle(l930, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- movq(xmm1, qword[A2-0x80]);
- sub(A2, -8);
- punpckldq(xmm0, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l930);
- test(M, 0x4);
- jle(l974, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- movd(xmm1, dword[A2-0x80]);
- sub(A2, -4);
- punpckldq(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l974);
- test(M, 0x2);
- jle(l9b8, T_NEAR);
- mov(ax, word[A1-0x80]);
- sub(A1, -2);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A2-0x80]);
- sub(A2, -2);
- pinsrw(xmm0, eax, 0x1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l9b8);
- test(M, 0x1);
- jle(l9ec, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- mov(byte[B-0x80], al);
- mov(al, byte[A2-0x80]);
- pinsrb(xmm0, eax, 0x1);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- align(4);
-
-L(l9ec);
- mov(A1, qword[ARG_BIAS]);
- movq(qword[A1], xmm7);
- add(qword[ARG_BIAS], 0x8);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l828, T_NEAR);
- align(4);
-
-L(la0a);
- cmp(N, 0x1);
- jl(lb58, T_NEAR);
- align(4);
-
-L(la14);
- mov(A1, A);
- add(A, LDA);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x4);
- jle(la6c, T_NEAR);
- align(4);
-
-L(la28);
- movdqu(xmm0, xword[A1-0x80]);
- sub(A1, -16);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(I);
- jg(la28, T_NEAR);
- align(4);
-
-L(la6c);
- test(M, 0x8);
- jle(laa8, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- sub(A1, -8);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(laa8);
- test(M, 0x4);
- jle(lae0, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- sub(A1, -4);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(lae0);
- test(M, 0x2);
- jle(lb14, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- mov(word[B-0x80], ax);
- sub(A1, -2);
- sub(B, -2);
- align(4);
-
-L(lb14);
- test(M, 0x1);
- jle(lb38, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrb(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(lb38);
- mov(A1, qword[ARG_BIAS]);
- movd(dword[A1], xmm7);
- add(qword[ARG_BIAS], 0x4);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(la14, T_NEAR);
- align(4);
-
-L(lb58);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-#undef ARG_BIAS
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
deleted file mode 100644
index afe4f1713e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
+++ /dev/null
@@ -1,647 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_generator.hpp"
-#include "common.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
-{
-
-#ifndef _WIN32
-#define M rdi
-#define N rsi
-#define A rdx
-#define LDA rcx
-#define ALPHA r8
-#define B r9
-
-#define I rax
-#define A1 r10
-#define A2 r8
-#define LDA3 r11
-
-#define ARG_BIAS 24+stacksize+rsp
-
-#else
-
-#define M rcx
-#define N rdx
-#define A r8
-#define LDA r9
-#define ALPHA rax
-#define B rdi
-
-#define I rax
-#define A1 rsi
-#define A2 r10
-#define LDA3 r11
-
-#define ARG_ALPHA 40+stacksize+rsp
-#define ARG_B 48+stacksize+rsp
-#define ARG_BIAS 72+stacksize+rsp
-
-#endif
-
-inLocalLabel();
-{
-
-Xbyak::Label l15c;
-Xbyak::Label l1f4;
-Xbyak::Label l20;
-Xbyak::Label l248;
-Xbyak::Label l280;
-Xbyak::Label l2a4;
-Xbyak::Label l2b0;
-Xbyak::Label l2c8;
-Xbyak::Label l384;
-Xbyak::Label l3e8;
-Xbyak::Label l40;
-Xbyak::Label l424;
-Xbyak::Label l448;
-Xbyak::Label l468;
-Xbyak::Label l474;
-Xbyak::Label l48c;
-Xbyak::Label l550;
-Xbyak::Label l5bc;
-Xbyak::Label l600;
-Xbyak::Label l628;
-Xbyak::Label l646;
-Xbyak::Label l650;
-Xbyak::Label l668;
-Xbyak::Label l700;
-Xbyak::Label l760;
-Xbyak::Label l7a4;
-Xbyak::Label l7c8;
-Xbyak::Label l7e8;
-
- preamble();
- auto stacksize = get_size_of_abi_save_regs();
-#ifdef _WIN32
- mov(ALPHA, ptr[ARG_ALPHA]);
- mov(B, ptr[ARG_B]);
-#endif
-
- mov(M, qword[M]);
- mov(N, qword[N]);
- mov(LDA, qword[LDA]);
- lea(LDA3, ptr[LDA+LDA*2]);
- sub(A, -128);
- sub(B, -128);
- cmp(N, 0x8);
- jl(l2a4, T_NEAR);
- align(4);
-
-L(l20);
- mov(A1, A);
- add(A, 0x8);
- pxor(xmm8, xmm8);
- pxor(xmm9, xmm9);
- mov(I, M);
- sar(I, 0x3);
- jle(l15c, T_NEAR);
- align(4);
-
-L(l40);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x60], xmm0);
- movdqu(xword[B-0x50], xmm1);
- sub(B, -64);
- dec(I);
- jg(l40, T_NEAR);
- align(4);
-
-L(l15c);
- test(M, 0x4);
- jle(l1f4, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm2, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm3, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- movdqa(xmm1, xmm0);
- punpcklwd(xmm0, xmm2);
- punpckhwd(xmm1, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- pmovsxbw(xmm5, xmm1);
- movhlps(xmm6, xmm1);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm9, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movdqu(xword[B-0x70], xmm1);
- sub(B, -32);
- align(4);
-
-L(l1f4);
- test(M, 0x2);
- jle(l248, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- movq(xmm1, qword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm8, xmm5);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm6, xmm6);
- pmovsxwd(xmm6, xmm6);
- paddd(xmm9, xmm6);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l248);
- test(M, 0x1);
- jle(l280, T_NEAR);
- movq(xmm0, qword[A1-0x80]);
- add(A1, LDA);
- pmovsxbd(xmm5, xmm0);
- pshufd(xmm6, xmm0, 0x55);
- pmovsxbd(xmm6, xmm6);
- paddd(xmm8, xmm5);
- paddd(xmm9, xmm6);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l280);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm8);
- movdqu(xword[A1+0x10], xmm9);
- add(qword[ARG_BIAS], 0x20);
- sub(N, 0x8);
- cmp(N, 0x8);
- jge(l20, T_NEAR);
- align(4);
-
-L(l2a4);
- cmp(N, 0x4);
- jl(l468, T_NEAR);
- align(4);
-
-L(l2b0);
- mov(A1, A);
- add(A, 0x4);
- pxor(xmm7, xmm7);
- mov(I, M);
- sar(I, 0x3);
- jle(l384, T_NEAR);
- align(4);
-
-L(l2c8);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x70], xmm0);
- sub(B, -32);
- dec(I);
- jg(l2c8, T_NEAR);
- align(4);
-
-L(l384);
- test(M, 0x4);
- jle(l3e8, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm2, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm3, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- movhlps(xmm6, xmm0);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- align(4);
-
-L(l3e8);
- test(M, 0x2);
- jle(l424, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- add(A1, LDA);
- movd(xmm1, dword[A1-0x80]);
- add(A1, LDA);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l424);
- test(M, 0x1);
- jle(l448, T_NEAR);
- movd(xmm0, dword[A1-0x80]);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l448);
- mov(A1, qword[ARG_BIAS]);
- movdqu(xword[A1], xmm7);
- add(qword[ARG_BIAS], 0x10);
- sub(N, 0x4);
- cmp(N, 0x4);
- jge(l2b0, T_NEAR);
- align(4);
-
-L(l468);
- cmp(N, 0x2);
- jl(l646, T_NEAR);
- align(4);
-
-L(l474);
- mov(A1, A);
- add(A, 0x2);
- pxor(xmm7, xmm7);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l550, T_NEAR);
- align(4);
-
-L(l48c);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm4, eax, 0x0);
- punpcklbw(xmm1, xmm2);
- punpcklbw(xmm3, xmm4);
- punpcklwd(xmm1, xmm3);
- punpcklqdq(xmm0, xmm1);
- pshufd(xmm6, xmm0, 0xd8);
- pmovsxbw(xmm5, xmm6);
- movhlps(xmm6, xmm6);
- pmovsxbw(xmm6, xmm6);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movdqu(xword[B-0x80], xmm0);
- sub(B, -16);
- dec(LDA3);
- jg(l48c, T_NEAR);
- align(4);
-
-L(l550);
- test(M, 0x4);
- jle(l5bc, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm2, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm3, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- punpcklbw(xmm2, xmm3);
- punpcklwd(xmm0, xmm2);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- align(4);
-
-L(l5bc);
- test(M, 0x2);
- jle(l600, T_NEAR);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm0, eax, 0x0);
- mov(ax, word[A1-0x80]);
- add(A1, LDA);
- pinsrw(xmm1, eax, 0x0);
- punpcklbw(xmm0, xmm1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l600);
- test(M, 0x1);
- jle(l628, T_NEAR);
- mov(ax, word[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(word[B-0x80], ax);
- sub(B, -2);
- align(4);
-
-L(l628);
- mov(A1, qword[ARG_BIAS]);
- movq(qword[A1], xmm7);
- add(qword[ARG_BIAS], 0x8);
- sub(N, 0x2);
- cmp(N, 0x2);
- jge(l474, T_NEAR);
- align(4);
-
-L(l646);
- cmp(N, 0x1);
- jl(l7e8, T_NEAR);
- align(4);
-
-L(l650);
- mov(A1, A);
- add(A, 0x1);
- pxor(xmm7, xmm7);
- mov(LDA3, M);
- sar(LDA3, 0x3);
- jle(l700, T_NEAR);
- align(4);
-
-L(l668);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x4);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x5);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x6);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x7);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm6);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movq(qword[B-0x80], xmm0);
- sub(B, -8);
- dec(LDA3);
- jg(l668, T_NEAR);
- align(4);
-
-L(l700);
- test(M, 0x4);
- jle(l760, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x2);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x3);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- movd(dword[B-0x80], xmm0);
- sub(B, -4);
- align(4);
-
-L(l760);
- test(M, 0x2);
- jle(l7a4, T_NEAR);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x0);
- mov(byte[B-0x80], al);
- mov(al, byte[A1-0x80]);
- add(A1, LDA);
- pinsrb(xmm0, eax, 0x1);
- pmovsxbw(xmm5, xmm0);
- phaddw(xmm5, xmm5);
- pmovsxwd(xmm5, xmm5);
- paddd(xmm7, xmm5);
- mov(byte[B-0x7f], al);
- sub(B, -2);
- align(4);
-
-L(l7a4);
- test(M, 0x1);
- jle(l7c8, T_NEAR);
- mov(al, byte[A1-0x80]);
- pinsrw(xmm0, eax, 0x0);
- pmovsxbd(xmm5, xmm0);
- paddd(xmm7, xmm5);
- mov(byte[B-0x80], al);
- sub(B, -1);
- align(4);
-
-L(l7c8);
- mov(A1, qword[ARG_BIAS]);
- movd(dword[A1], xmm7);
- add(qword[ARG_BIAS], 0x4);
- sub(N, 0x1);
- cmp(N, 0x1);
- jge(l650, T_NEAR);
- align(4);
-
-L(l7e8);
-
- postamble();
-}
-outLocalLabel();
-
-#undef M
-#undef N
-#undef A
-#undef LDA
-#undef ALPHA
-#undef B
-#undef I
-#undef A1
-#undef A2
-#undef LDA3
-#ifdef _WIN32
-#undef ARG_ALPHA
-#undef ARG_B
-#endif
-#undef ARG_BIAS
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
deleted file mode 100644
index 4fc11afcbc..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
+++ /dev/null
@@ -1,116 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <cstdint>
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "../f32/ref_gemm_f32.hpp"
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <typename b_dt>
-mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
- int32_t *C, const int *LDC, const int32_t *co) {
-
- if (*M == 0 || *N == 0 || *K == 0)
- return mkldnn_success;
-
- bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
- bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
- bool AisN = (*transa == 'N' || *transa == 'n');
- bool BisN = (*transb == 'N' || *transb == 'n');
-
- int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC;
- size_t sizeA = AisN ? lda * k : lda * m;
- size_t sizeB = BisN ? ldb * n : ldb * k;
- size_t sizeC = ldc * n;
-
- double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K);
- double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K);
- double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K);
-
- if (utils::any_null(dA, dB, dC)) {
- free(dA);
- free(dB);
- free(dC);
- return mkldnn_out_of_memory;
- }
-
- auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
- auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
-
- auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
- auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
-
- const int a_rows = AisN ? m : k;
- const int a_cols = AisN ? k : m;
- mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
- da_setter(i, j,
- static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
- });
-
- const int b_rows = BisN ? k : n;
- const int b_cols = BisN ? n : k;
- mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
- db_setter(i, j,
- static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
- });
- double one = 1.0, zero = 0.0;
- ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero,
- dC, LDC, nullptr);
-
- auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
- auto f2d = [=] (float v) { return static_cast<double>(v); };
-
- mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
- double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
- double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc]))
- + f2d(*alpha) * dC[i + j * ldc] + coffset;
- C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val));
- });
-
- free(dA);
- free(dB);
- free(dC);
- return mkldnn_success;
-}
-
-template mkldnn_status_t ref_gemm_s8x8s32<uint8_t>(
- const char *transa, const char *transb, const char *offsetc,
- const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const uint8_t *B, const int *LDB, const int8_t *bo,
- const float *beta, int32_t *C, const int *LDC, const int32_t *co);
-
-template mkldnn_status_t ref_gemm_s8x8s32<int8_t>(
- const char *transa, const char *transb, const char *offsetc,
- const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const int8_t *B, const int *LDB, const int8_t *bo,
- const float *beta, int32_t *C, const int *LDC, const int32_t *co);
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
deleted file mode 100644
index 6c0370ae99..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
+++ /dev/null
@@ -1,38 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef REF_GEMM_S8X8S32_HPP
-#define REF_GEMM_S8X8S32_HPP
-
-#include <stdint.h>
-
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <typename b_dt>
-mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
- const char *offsetc, const int *M, const int *N, const int *K,
- const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
- const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
- int32_t *C, const int *LDC, const int32_t *co);
-
-}
-}
-}
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
deleted file mode 100644
index de1035f3b2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
+++ /dev/null
@@ -1,180 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "common.hpp"
-#include "nstl.hpp"
-#include "math_utils.hpp"
-
-#include "../gemm.hpp"
-#include "jit_avx512_core_gemm_s8u8s32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-void compensation_init(const char *offsetC, int32_t *compensation, int len,
- const int32_t *oc) {
- bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
- bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
-
- if (OCisF && (*oc) != 0) {
- for (int i = 0; i < len; i++)
- compensation[i] = *oc;
- } else if (OCisC) {
- for (int i = 0; i < len; i++)
- compensation[i] = oc[i];
- } else {
- parallel_nd(len, [=](int i) { compensation[i] = 0; });
- }
-}
-
-void compensation_compute(bool transa, int m, int k, float alpha,
- const int8_t *a, int lda, int32_t *compensation) {
- if (!transa) {
- const int L2_cache_size = get_cache_size(2, true);
- const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
- const int npanels = k / blocking_factor;
- const bool has_tile = k % blocking_factor > 0;
-
- parallel_nd(npanels, m, [&](int j, int i) {
- int32_t val = 0;
- for (int jb = 0; jb < blocking_factor; jb++) {
- val += a[(i + (ptrdiff_t)j * blocking_factor * lda)
- + (ptrdiff_t)jb * lda];
- }
- if (alpha != 1.0f) {
- val = math::out_round<int32_t>(math::saturate<int32_t>(
- (double)val * alpha * -128.0));
- } else {
- val *= -128;
- }
- fetch_and_add(&compensation[i], val);
- });
-
- if (has_tile) {
- parallel_nd(m, [=](int i) {
- int32_t val = 0;
- for (int j = npanels * blocking_factor; j < k; j++) {
- val += a[i + (ptrdiff_t)j * lda];
- }
- if (alpha != 1.0f) {
- val = math::out_round<int32_t>(math::saturate<int32_t>(
- (double)val * alpha * -128.0));
- } else {
- val *= -128;
- }
- fetch_and_add(&compensation[i], val);
- });
- }
- } else {
- parallel_nd(m, [=](int i) {
- int32_t val = 0;
- for (int j = 0; j < k; j++) {
- val += a[j + (ptrdiff_t)i * lda];
- }
- if (alpha != 1.0f) {
- val = math::out_round<int32_t>(math::saturate<int32_t>(
- (double)val * alpha * -128.0));
- } else {
- val *= -128;
- }
- compensation[i] += val;
- });
- }
-}
-
-void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8,
- const int8_t *b_s8, int ldb_s8) {
- const int b_cols = transb ? k : n;
-
- parallel_nd(b_cols, [=](int j) {
- const int b_rows = transb ? n : k;
-
- uint8_t *pb_u8 = b_u8 + j * ldb_u8;
- const int8_t *pb_s8 = b_s8 + j * ldb_s8;
-
- for (int i = 0; i < b_rows; i++) {
- (*pb_u8) = (*pb_s8) + 128;
- pb_u8++;
- pb_s8++;
- }
- });
-}
-
-/**
- * gemm_s8s8s32 operation is defined as follows:
- * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation
- *
- * where
- * - compensation is a vector of length m that contains computed compensation
- * that may contain C_offset if applicable. The compensation is applied inside
- * gemm_s8u8s32 as a C_offset
- * - B_shift is a k-by-n matrix, every element of B_shift is equal to 128
- *
- * What is the compensation:
- * In order to prepare the matrix B for gemm_s8u8s32 call the B_shift is applied:
- * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset =
- * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset
- * compensation = -alpha * op(A) * B_shift
- * Since B_shift is a matrix, every element of which is equal to 128 then
- * - if op(A) = A: compensation contains sum of the elements in each row
- * scaled by -128 * alpha
- * - if op(A) = A**T: compensation contains sum of the elements in each column
- * scaled by -128 * alpha
- *
- * The rest of parameters is described in mkldnn.h
- */
-mkldnn_status_t simple_gemm_s8s8s32(
- const char *transA, const char *transB, const char *offsetC,
- const int *m, const int *n, const int *k,
- const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
- const int8_t *b, const int *ldb, const int8_t *ob,
- const float *beta, int32_t *c, const int *ldc, const int32_t *oc) {
- if (*oa != 0 || *ob != 0) return mkldnn_unimplemented;
-
- int M = *m, N = *n, K = *k;
- bool transa = (*transA == 'T' || *transA == 't');
- bool transb = (*transB == 'T' || *transB == 't');
- int ld = transb ? N : K;
-
- uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64);
- int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64);
-
- if (utils::any_null(b_u8, compensation)) {
- free(b_u8);
- free(compensation);
- return mkldnn_out_of_memory;
- }
-
- compensation_init(offsetC, compensation, M, oc);
- compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
- copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
-
- gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8,
- &ld, ob, beta, c, ldc, compensation);
-
- if ((*offsetC == 'R' || *offsetC == 'r'))
- parallel_nd(M, N,
- [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; });
-
- free(b_u8);
- free(compensation);
-
- return mkldnn_success;
-}
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
deleted file mode 100644
index 03a3d2f7e0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
+++ /dev/null
@@ -1,37 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SIMPLE_GEMM_S8S8S32_HPP
-#define SIMPLE_GEMM_S8S8S32_HPP
-
-#include <stdint.h>
-#include "mkldnn_types.h"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-mkldnn_status_t simple_gemm_s8s8s32(
- const char *transA, const char *transB, const char *offsetC,
- const int *m, const int *n, const int *k,
- const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
- const int8_t *b, const int *ldb, const int8_t *ob,
- const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
-}
-}
-}
-
-#endif // SIMPLE_GEMM_S8S8S32_HPP
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp
deleted file mode 100644
index 604a728b47..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp
+++ /dev/null
@@ -1,307 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "gemm_convolution.hpp"
-#include "utils.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "ref_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
-
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- const int M = jcp.os * jcp.od;
- const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
- const size_t dst_step = jcp.oc * M;
- const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
-
- assert(IMPLICATION(
- jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
- assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
-
- const int K = jcp.ic * jcp.ks;
- const int N = jcp.oc;
-
- if (jcp.im2col_sz && jcp.id != 1)
- parallel_nd(jcp.im2col_sz * jcp.nthr,
- [&](ptrdiff_t i) { col[i] = (data_t)0; });
-
- const int nb_oh = div_up(jcp.oh, jcp.oh_block);
- const int nb_ow = div_up(jcp.ow, jcp.ow_block);
- const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
-
- int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
- size_t start = 0, end = 0;
-
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb,
- nb_oh, owb, nb_ow);
- for (size_t iwork = start; iwork < end; ++iwork) {
- int oh = ohb * jcp.oh_block;
- int ow = owb * jcp.ow_block;
- const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
- const data_t *_weights = weights + g * weights_g_size;
- data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
- const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
- const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
- if (jcp.im2col_sz) {
- if (jcp.id == 1)
- jit_gemm_convolution_utils::im2col(
- jcp, _src, _col, oh, h_step, ow, w_step);
- else
- jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
- }
-
- const data_t one = 1.0;
-
- const int m = h_step * w_step;
- const int LDA = jcp.im2col_sz ? m : M;
- data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
-
- extended_sgemm("N", "N", &m, &N, &K, &one,
- jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
- &this->beta_, _dst, &M);
-
- data_t *d = _dst;
- if (eltwise_) {
- // fast branch for ReLU case
- if (eltwise_->alg_ == alg_kind::eltwise_relu) {
- parallel_nd(jcp.oc, [&](const int oc) {
- data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
- data_t *d_ = d + oc * M;
- PRAGMA_OMP_SIMD()
- for (int oS = 0; oS < m; ++oS) {
- d_[oS] += b;
- if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_;
- }
- });
- } else {
- parallel_nd(jcp.oc, [&](const int oc) {
- data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
- data_t *d_ = d + oc * M;
- PRAGMA_OMP_SIMD()
- for (int oS = 0; oS < m; ++oS) {
- d_[oS] += b;
- d_[oS] = eltwise_->compute_scalar(d_[oS]);
- }
- });
- }
- } else if (jcp.with_bias) {
- parallel_nd(jcp.oc, [&](const int oc) {
- data_t b = bias[g * jcp.oc + oc];
- data_t *d_ = d + oc * M;
- PRAGMA_OMP_SIMD()
- for (int oS = 0; oS < m; ++oS) {
- d_[oS] += b;
- }
- });
- }
- nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh,
- owb, nb_ow);
- }
- });
-}
-
-void gemm_convolution_bwd_data_t::execute_backward_data(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
-
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- const int M = jcp.os * jcp.od;
- const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
- const size_t dst_step = jcp.oc * M;
- const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
-
- const int m = jcp.os;
- const int K = jcp.oc;
- const int N = jcp.ic * jcp.ks;
- const int LDC = jcp.im2col_sz ? m : M;
-
- const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
-
- if (jcp.id > 1) {
- const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step);
- parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; });
- }
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
-
- int g{0}, n{0};
- size_t start = 0, end = 0;
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
- for (size_t iwork = start; iwork < end; ++iwork) {
-
- data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
- const data_t *_weights = weights + g * weights_g_size;
- for (int od = 0; od < jcp.od; ++od) {
- const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
- *dst_step + od * m;
-
- const data_t zero = 0.0, one = 1.0;
- extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M,
- _weights, &N, &zero,
- jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
-
- if (jcp.im2col_sz) {
- if (jcp.id == 1)
- jit_gemm_convolution_utils::col2im(jcp, _col,
- _diff_src);
- else
- jit_gemm_convolution_utils::col2im_3d(jcp, _col,
- _diff_src, od);
- }
- }
- nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
- }
- });
-}
-
-void gemm_convolution_bwd_weights_t::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
- auto wei_reduction = scratchpad(ctx).get<data_t>(key_conv_wei_reduction);
-
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- const int K = jcp.os * jcp.od;
- const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
- const size_t dst_step = jcp.oc * K;
- const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
-
- const int k = jcp.os;
- const int N = jcp.oc;
- const int M = jcp.ic * jcp.ks;
- const int LDA = jcp.im2col_sz ? k : K;
-
- parallel_nd(jcp.im2col_sz * jcp.nthr,
- [&](ptrdiff_t i) { col[i] = (data_t)0; });
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- int ithr_g, nthr_g, ithr_mb, nthr_mb;
- size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0};
-
- const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
- jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
- mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);
-
- assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
- const int need_reduction = nthr_mb != 1;
-
- if (ithr_g != -1 && ithr_mb != -1) {
- balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
- balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
-
- assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
-
- data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
- data_t *weights_reduce_base = wei_reduction
- + ithr_g * nthr_mb * weights_g_size;
- data_t *weights_reduce = weights_reduce_base
- + ithr_mb * weights_g_size;
-
- for (size_t g = g_start; g < g_end; ++g) {
- data_t *_diff_weights = need_reduction
- ? weights_reduce : (diff_weights + g * weights_g_size);
- for (size_t mb = mb_start; mb < mb_end; ++mb) {
- const data_t *_src = src + (mb*jcp.ngroups+g)*src_step;
- for (int od = 0; od < jcp.od; ++od) {
- const data_t *_diff_dst = diff_dst
- + (mb*jcp.ngroups+g)*dst_step + od * k;
-
- if (jcp.im2col_sz) {
- if (jcp.id == 1)
- jit_gemm_convolution_utils::im2col(
- jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
- else
- jit_gemm_convolution_utils::im2col_3d(jcp, _src,
- _col, od);
- }
-
- const data_t zero = 0.0, one = 1.0;
- extended_sgemm(
- "T", "N", &M, &N, &k, &one,
- jcp.im2col_sz ? _col : _src + od * k,
- &LDA, _diff_dst, &K,
- mb == mb_start && od == 0 ? &zero : &one,
- _diff_weights, &M);
- }
- }
- }
- if (need_reduction) {
- mkldnn_thr_barrier();
- data_t *weights_base = diff_weights + g_start * weights_g_size;
- jit_gemm_convolution_utils::bwd_weights_reduction_par(
- ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base);
- }
- } else
- if (need_reduction) { mkldnn_thr_barrier(); }
- });
-
- if (jcp.with_bias) {
- parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
- data_t db = 0;
- size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
- for (int mb = 0; mb < jcp.mb; ++mb)
- {
- size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
- for (int od = 0; od < jcp.od; ++od)
- for (int oh = 0; oh < jcp.oh; ++oh)
- PRAGMA_OMP_SIMD(reduction(+:db))
- for (int ow = 0; ow < jcp.ow; ++ow) {
- db += diff_dst[offset];
- offset++;
- }
- }
- diff_bias[g*jcp.oc+oc] = db;
- });
- }
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp
deleted file mode 100644
index 302e46369a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp
+++ /dev/null
@@ -1,250 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP
-#define CPU_JIT_GEMM_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "gemm_convolution_utils.hpp"
-#include "gemm/gemm.hpp"
-#include "ref_eltwise.hpp"
-
-#include "cpu_convolution_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct gemm_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc, const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
- && post_ops_ok()
- && memory_desc_matches_tag(*src_md(), dat_tag())
- && memory_desc_matches_tag(*dst_md(), dat_tag())
- && memory_desc_matches_tag(*weights_md(), wei_tag());
- if (!ok) return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
- *desc(), src_md(), weights_md(0), dst_md(),
- mkldnn_get_max_threads());
- }
-
- jit_gemm_conv_conf_t jcp_;
-
- protected:
- format_tag_t dat_tag() const {
- using namespace format_tag;
- return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- }
-
- format_tag_t wei_tag() const {
- using namespace format_tag;
- return with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- }
-
- bool post_ops_ok() const {
- auto const &po = attr()->post_ops_;
- auto is_eltwise = [&](int idx)
- { return po.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
-
- switch (po.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
- return false;
- }
- };
-
- gemm_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd, true)
- , eltwise_(nullptr)
- {
- const auto &post_ops = pd()->attr()->post_ops_;
- const data_t one = 1.0, zero = 0.0;
- beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero;
-
- const int entry_idx = post_ops.find(primitive_kind::eltwise);
- if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t(
- post_ops.entry_[entry_idx].eltwise);
- }
-
- ~gemm_convolution_fwd_t() { delete eltwise_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- data_t beta_;
-
- ref_eltwise_scalar_fwd_t* eltwise_;
-};
-
-struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc, const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
- && memory_desc_matches_tag(*diff_src_md(), dat_tag())
- && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
- && memory_desc_matches_tag(*weights_md(), wei_tag());
- if (!ok) return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
- *desc(), diff_src_md(), weights_md(0), diff_dst_md(),
- mkldnn_get_max_threads());
- }
-
- jit_gemm_conv_conf_t jcp_;
-
- protected:
- format_tag_t dat_tag() const {
- using namespace format_tag;
- return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- }
-
- format_tag_t wei_tag() const {
- using namespace format_tag;
- return with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- }
- };
-
- gemm_convolution_bwd_data_t(const pd_t *apd)
- : cpu_primitive_t(apd, true) {}
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
- && memory_desc_matches_tag(*src_md(), dat_tag())
- && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
- && memory_desc_matches_tag(*diff_weights_md(), wei_tag());
- if (!ok) return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
- *desc(), src_md(), diff_weights_md(0), diff_dst_md(),
- mkldnn_get_max_threads());
- }
-
- jit_gemm_conv_conf_t jcp_;
-
- protected:
- format_tag_t dat_tag() const {
- using namespace format_tag;
- return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- }
-
- format_tag_t wei_tag() const {
- using namespace format_tag;
- return with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- }
- };
-
- gemm_convolution_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd, true) {}
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp
deleted file mode 100644
index f133b1e62b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp
+++ /dev/null
@@ -1,771 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-#include "cpu_isa_traits.hpp"
-
-#include "gemm_convolution_utils.hpp"
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-using namespace prop_kind;
-using namespace data_type;
-
-namespace jit_gemm_convolution_utils {
-
-void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
- int od)
-{
- const size_t OHW = jcp.oh * jcp.ow;
- const size_t im_step = jcp.ih * jcp.iw * jcp.id;
- const size_t col_step = jcp.ks * OHW;
-
- parallel_nd(jcp.ic, [&](int ic) {
- const float *__restrict im_loc = im + ic * im_step;
- float *__restrict col_loc = col + ic * col_step;
- int id = od * jcp.stride_d - jcp.f_pad;
- for (int kd = 0; kd < jcp.kd; ++kd) {
- float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
- if (id < 0 || id >= jcp.id) {
- int ih_ = -jcp.t_pad;
- for (int kh = 0; kh < jcp.kh; ++kh) {
- int ih = ih_;
- for (int oh = 0; oh < jcp.oh; ++oh) {
- if (ih < 0 || ih >= jcp.ih) {
- ih += jcp.stride_h;
- continue;
- }
- int iw_ = -jcp.l_pad;
- for (int kw = 0; kw < jcp.kw; ++kw) {
- int iw = iw_;
- for (int ow = 0; ow < jcp.ow; ++ow) {
- if (iw < 0 || iw >= jcp.iw) {
- iw += jcp.stride_w;
- continue;
- }
-
- const size_t col_idx = kw * OHW + oh * jcp.ow
- + ow;
-
- col_[col_idx] = 0;
- iw += jcp.stride_w;
- }
- iw_ += (1 + jcp.dilate_w);
- }
- ih += jcp.stride_h;
- }
- ih_ += (1 + jcp.dilate_h);
- col_ += jcp.kw * OHW;
- }
- } else {
- const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
- int ih_ = -jcp.t_pad;
- for (int kh = 0; kh < jcp.kh; ++kh) {
- int ih = ih_;
- for (int oh = 0; oh < jcp.oh; ++oh) {
- if (ih < 0 || ih >= jcp.ih) {
- ih += jcp.stride_h;
- continue;
- }
- int iw_ = -jcp.l_pad;
- for (int kw = 0; kw < jcp.kw; ++kw) {
- int iw = iw_;
- for (int ow = 0; ow < jcp.ow; ++ow) {
- if (iw < 0 || iw >= jcp.iw) {
- iw += jcp.stride_w;
- continue;
- }
-
- const size_t col_idx = kw * OHW + oh * jcp.ow
- + ow;
- const size_t im_idx = ih * jcp.iw + iw;
-
- col_[col_idx] = im_[im_idx];
- iw += jcp.stride_w;
- }
- iw_ += (1 + jcp.dilate_w);
- }
- ih += jcp.stride_h;
- }
- ih_ += (1 + jcp.dilate_h);
- col_ += jcp.kw * OHW;
- }
- }
- id += (1 + jcp.dilate_d);
- }
- });
-}
-
-/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
-void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
- float *__restrict col, int hs, int hb, int ws, int wb) {
- const size_t im_step = jcp.is;
- const size_t col_step = jcp.ks * hb * wb;
- if (jcp.stride_w == 1) {
- // Generated code is more optimized for stride_w == 1
- // because innermost loop is by width
- auto ker = [&](int ic, int kh, int kw, int oh) {
- const float *__restrict im_ = im + ic * im_step;
- float *__restrict col_
- = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;
-
- const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
- + kh * (1 + jcp.dilate_h);
- if (ih < 0 || ih >= jcp.ih) {
- for (int ow = 0; ow < wb; ++ow)
- col_[ow] = 0.f;
- } else {
- for (int ow = 0; ow < wb; ++ow) {
- const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w);
- if (iw < 0 || iw >= jcp.iw)
- col_[ow] = 0.f;
- else {
- const size_t im_idx = ih * jcp.iw + iw;
- col_[ow] = im_[im_idx];
- }
- }
- }
- };
-
- if (jcp.outer_threading) {
- for (int ic = 0; ic < jcp.ic; ic++)
- for (int kh = 0; kh < jcp.kh; kh++)
- for (int kw = 0; kw < jcp.kw; kw++)
- for (int oh = 0; oh < hb; oh++)
- ker(ic, kh, kw, oh);
- }
- else {
- parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
- }
- } else if (jcp.ic == 1) {
- parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
- const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
- + kh * (1 + jcp.dilate_h);
- if (ih < 0 || ih >= jcp.ih)
- for (int kw = 0; kw < jcp.kw; ++kw) {
- for (int ow = 0; ow < wb; ++ow) {
- const size_t col_idx
- = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
- col[col_idx] = 0;
- }
- }
- else
- for (int kw = 0; kw < jcp.kw; ++kw) {
- for (int ow = 0; ow < wb; ++ow) {
- const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
- + kw * (1 + jcp.dilate_w);
- const size_t col_idx
- = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
- const size_t im_idx = ih * jcp.iw + iw;
- if (iw < 0 || iw >= jcp.iw)
- col[col_idx] = 0;
- else
- col[col_idx] = im[im_idx];
- }
- }
- });
- } else {
-
- parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
- [&](int ic, int kh, int kw, int oh) {
- const float *__restrict im_ = im + ic * im_step;
- float *__restrict col_ = col + ic * col_step
- + ((kh * jcp.kw + kw) * hb + oh) * wb;
-
- const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
- + kh * (1 + jcp.dilate_h);
- if (ih < 0 || ih >= jcp.ih) {
- for (int ow = 0; ow < wb; ++ow)
- col_[ow] = 0.f;
- } else {
- for (int ow = 0; ow < wb; ++ow) {
- const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
- + kw * (1 + jcp.dilate_w);
- const size_t im_idx = ih * jcp.iw + iw;
- if (iw < 0 || iw >= jcp.iw)
- col_[ow] = 0.f;
- else
- col_[ow] = im_[im_idx];
- }
- }
- });
- }
-}
-
-inline int limit(int low, int upper, int value) {
- return nstl::max(low, nstl::min(upper, value));
-}
-
-/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */
-template <typename T>
-void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
- T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws,
- int wb) {
- uint8_t shift = jcp.signed_input ? 128 : 0;
- const int dh = 1 + jcp.dilate_h;
- const int dw = 1 + jcp.dilate_w;
- const int sh = jcp.stride_h;
- const int sw = jcp.stride_w;
- const int im_iw_stride = jcp.ic * jcp.ngroups;
- const int im_ih_stride = jcp.iw * im_iw_stride;
- const int tp = jcp.t_pad;
- const int lp = jcp.l_pad;
-
- if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
- /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
- const int hp = hs - tp;
- const int wp = ws - lp;
- const int ih_start = limit(0, jcp.ih, hp);
- const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh);
- const int iw_start = limit(0, jcp.iw, wp);
- const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw);
-
- const int ihb = ih_end - ih_start;
- const int iwb = iw_end - iw_start;
-
- const int imtr_ic_stride = ihb * iwb;
- const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
- for (int ic = 0; ic < jcp.ic; ic++) {
- const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
- for (int ih = ih_start; ih < ih_end; ih++) {
- const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
- const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
- for (int iw = iw_start; iw < iw_end; iw++)
- imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
- }
- }
-
- const int col_ic_str = hb * wb;
- const int col_kw_stride = jcp.ic * col_ic_str;
- const int col_kh_stride = jcp.kw * col_kw_stride;
-
- const int oh_init = ih_start - hp;
- const int ow_init = iw_start - wp;
- for (int kh = 0; kh < jcp.kh; kh++) {
- const ptrdiff_t col_idx_kh = kh * col_kh_stride;
- const int oh_kh = oh_init - kh;
- const int oh_start = limit(0, hb, oh_kh);
- const int oh_end = limit(0, hb, oh_kh + ihb);
- for (int kw = 0; kw < jcp.kw; kw++) {
- const ptrdiff_t col_idx_kw
- = col_idx_kh + kw * jcp.ic * col_ic_str;
- const int ow_kw = ow_init - kw;
- const int imtr_shift = oh_kh * iwb + ow_kw;
- const int ow_start = limit(0, wb, ow_kw);
- const int ow_end = limit(0, wb, ow_kw + iwb);
- for (int ic = 0; ic < jcp.ic; ic++) {
- const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
- const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
- for (int oh = 0; oh < oh_start; oh++) {
- const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
- for (int ow = 0; ow < wb; ++ow)
- col[col_idx_oh + ow] = shift;
- }
- for (int oh = oh_start; oh < oh_end; oh++) {
- const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
- const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
- for (int ow = 0; ow < ow_start; ++ow)
- col[col_idx_oh + ow] = shift;
- for (int ow = ow_start; ow < ow_end; ++ow)
- col[col_idx_oh + ow]
- = imtr[imtr_idx_oh + ow] + shift;
- for (int ow = ow_end; ow < wb; ++ow)
- col[col_idx_oh + ow] = shift;
- }
- for (int oh = oh_end; oh < hb; oh++) {
- const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
- for (int ow = 0; ow < wb; ++ow)
- col[col_idx_oh + ow] = shift;
- }
- }
- }
- }
- } else {
- parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
- [&](int kh, int kw, int ic, int oh) {
- const int hp = tp - kh * dh;
- const int ih = (oh + hs) * sh - hp;
- const ptrdiff_t col_idx_base
- = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb;
- if (ih < 0 || ih >= jcp.ih)
- for (int ow = 0; ow < wb; ow++)
- col[col_idx_base + ow] = shift;
- else {
- const int wp = lp - kw * dw;
- const int ow_start = limit(0, wb, div_up(wp, sw) - ws);
- const int ow_end
- = limit(0, wb, div_up(jcp.iw + wp, sw) - ws);
- for (int ow = 0; ow < ow_start; ow++)
- col[col_idx_base + ow] = shift;
- const int iw_base = ws * sw - wp;
- const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
- for (int ow = ow_start; ow < ow_end; ow++) {
- const int iw = iw_base + ow * sw;
- const ptrdiff_t im_idx
- = im_idx_base + iw * im_iw_stride;
- col[col_idx_base + ow] = im[im_idx] + shift;
- }
- for (int ow = ow_end; ow < wb; ow++)
- col[col_idx_base + ow] = shift;
- }
- });
- }
-}
-
-template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
- const int8_t *__restrict im, int8_t *__restrict imtr,
- uint8_t *__restrict col, int hs, int hb, int ws, int wb);
-template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
- const uint8_t *__restrict im, uint8_t *__restrict imtr,
- uint8_t *__restrict col, int hs, int hb, int ws, int wb);
-
-/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
-void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
- int32_t *__restrict im)
-{
- parallel(0, [&](const int ithr, const int nthr) {
- int h_nthr = nstl::min(jcp.ih, nthr);
- int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
- int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
- if (ithr < h_nthr * w_nthr) {
- h_ithr = ithr / w_nthr;
- w_ithr = ithr % w_nthr;
- balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
- balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
- } else {
- h_ithr = w_ithr = -ithr;
- h_s = h_e = w_s = w_e = -1;
- }
-
- for (int ih = h_s; ih < h_e; ++ih) {
- for (int iw = w_s; iw < w_e; ++iw) {
- PRAGMA_OMP_SIMD()
- for (int ic = 0; ic < jcp.ic; ++ic) {
- im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
- }
- }
- }
-
- // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
- for (int oh = 0; oh < jcp.oh; ++oh) {
- for (int ow = 0; ow < jcp.ow; ++ow) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- const int ih = oh * jcp.stride_h
- - jcp.t_pad + kh * (1 + jcp.dilate_h);
- if (ih < h_s || ih >= h_e) continue;
-
- for (int kw = 0; kw < jcp.kw; ++kw) {
- const int iw = ow * jcp.stride_w
- - jcp.l_pad + kw * (1 + jcp.dilate_w);
- if (iw < w_s || iw >= w_e) continue;
-
- const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
- + kh) * jcp.kw + kw) * jcp.ic;
- const size_t im_idx
- = (ih * jcp.iw + iw) * jcp.ic;
- PRAGMA_OMP_SIMD()
- for (int ic = 0; ic < jcp.ic; ++ic) {
- im[im_idx + ic] += col[col_idx + ic];
- }
- }
- }
- }
- }
- });
-}
-
-void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
- int od)
-{
- parallel_nd(jcp.ic, [&](int ic) {
- const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
- float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
-
- int id = od * jcp.stride_d - jcp.f_pad;
- for (int kd = 0; kd < jcp.kd; ++kd) {
- if (id < 0 || id >= jcp.id) {
- col_ += jcp.kh * jcp.kw * jcp.os;
- id += (1 + jcp.dilate_d);
- continue;
- }
-
- float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;
-
- for (int oh = 0; oh < jcp.oh; ++oh) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- const int ih = oh * jcp.stride_h - jcp.t_pad
- + kh * (1 + jcp.dilate_h);
- if (ih < 0 || ih >= jcp.ih) continue;
-
- for (int ow = 0; ow < jcp.ow; ++ow) {
- for (int kw = 0; kw < jcp.kw; ++kw) {
- const int iw = ow * jcp.stride_w - jcp.l_pad
- + kw * (1 + jcp.dilate_w);
- if (iw < 0 || iw >= jcp.iw) continue;
-
- const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
- const size_t im_idx = ih*jcp.iw + iw;
- im_[im_idx] += col_[col_idx];
- }}
- }}
-
- col_ += jcp.kh * jcp.kw * jcp.os;
- id += (1 + jcp.dilate_d);
- }
- });
-}
-
-void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
- const size_t col_step = jcp.ks * jcp.os;
- const size_t im_step = jcp.ih * jcp.iw;
- const int iS = jcp.ih * jcp.iw;
-
- parallel_nd(jcp.ic, [&](int ic) {
- float *__restrict im_ = im + ic * im_step;
- const float *__restrict col_ = col + ic * col_step;
- PRAGMA_OMP_SIMD()
- for (int is = 0; is < iS; ++is) im_[is] = 0.;
-
- for (int kh = 0; kh < jcp.kh; ++kh) {
- for (int oh = 0; oh < jcp.oh; ++oh) {
- const int ih =
- oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
- if (ih < 0 || ih >= jcp.ih) continue;
-
- for (int kw = 0; kw < jcp.kw; ++kw) {
- for (int ow = 0; ow < jcp.ow; ++ow) {
- const int iw =
- ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
- if (iw < 0 || iw >= jcp.iw) continue;
-
- const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
- const size_t im_idx = ih*jcp.iw + iw;
- im_[im_idx] += col_[col_idx];
- }
- }
- }
- }
- });
-}
-
-status_t init_conf(jit_gemm_conv_conf_t &jcp,
- memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, int max_threads) {
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- const int ndims = src_d.ndims();
- const int is_1d = ndims == 3;
- const int is_3d = ndims == 5;
-
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.id = is_3d ? src_d.dims()[2] : 1;
- jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.od = is_3d ? dst_d.dims()[2] : 1;
- jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
- jcp.ow = dst_d.dims()[ndims - 1];
-
- jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
- jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
- jcp.l_pad = cd.padding[0][ndims - 3];
-
- jcp.stride_d = is_3d ? cd.strides[0] : 1;
- jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
- jcp.stride_w = cd.strides[ndims - 3];
-
- jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
- jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
- jcp.dilate_w = cd.dilates[ndims - 3];
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef
- || cd.diff_bias_desc.format_kind != format_kind::undef;
-
- jcp.is = jcp.ih * jcp.iw;
- jcp.os = jcp.oh * jcp.ow;
- jcp.ks = jcp.kh * jcp.kw * jcp.kd;
-
- jcp.signed_input = src_d.data_type() == data_type::s8;
-
- jcp.im2col_sz = !everyone_is(true,
- jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
- jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
- jcp.ks == 1, !jcp.signed_input)
- ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;
-
- jcp.outer_threading = false;
-
- bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
- && weights_d.data_type() == s8;
-
- const int vlen = mayiuse(avx512_common)
- ? cpu_isa_traits<avx512_common>::vlen
- : mayiuse(avx)
- ? cpu_isa_traits<avx>::vlen
- : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
- const int simd_w = vlen / (is_int8_conv ? 1 : 4);
-
- const bool is_bwd_d = jcp.prop_kind == backward_data;
- const bool is_bwd_w = jcp.prop_kind == backward_weights;
- const bool is_fwd = !is_bwd_d && !is_bwd_w;
- jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
- jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;
-
- using namespace memory_tracking::names;
- bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
-
- // TODO: maybe mitigate blocking restriction
- const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
- const int L2 = get_cache_size(2, true)
- / (is_int8_conv ? sizeof(int8_t) : sizeof(float));
- bool is_blocking_applicable = true
- && is_fwd && jcp.im2col_sz
- && jcp.id == 1 && jcp.od == 1
- && jcp.dilate_h == 0 && jcp.dilate_w == 0
- && !is_depthwise
- && wei_size < L2/2;
- if (is_blocking_applicable) {
- // looking for oh and ow blocking
- int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
- const int ic = jcp.ic;
- const int oc = jcp.oc;
- const int iw = jcp.iw;
- const int ow = jcp.ow;
- const int oh = jcp.oh;
- const int os = oh * ow;
-
- // 1. cache requirement
- int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
- if (is_int8_conv) {
- // Heuristic rule: gemm needed a lot of memory for internal usage
- row_size *= 5;
- // memory for accumulators
- row_size += oc * ow * sizeof(uint32_t);
- // memory for transposition
- row_size += ic * iw;
- }
-
- h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size)));
- if (h_block == 1) {
- int col_size = ic * jcp.ks + 2 * (ic + oc);
- if (is_int8_conv) {
- col_size *= 5;
- col_size += oc * sizeof(uint32_t);
- col_size += ic;
- }
- w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size)));
- }
-
- // 2. threading requirement
- if (h_block != oh)
- h_block = nstl::max(1, rnd_dn(h_block, 4));
- if (w_block != ow)
- w_block = nstl::max(1, rnd_dn(w_block, simd_w));
-
- float thr_eff = 0.f;
- float thr_eff_treshold = 0.9f;
- if (w_block == ow) {
- do {
- int nb_h = div_up(oh, h_block);
- size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
- float disb = (float)oh / rnd_up(oh, h_block);
- thr_eff = (float)work / rnd_up(work, max_threads);
- thr_eff = (thr_eff + disb) / 2.f;
- if (thr_eff >= thr_eff_treshold)
- break;
- h_block = rnd_dn(h_block - 4, 4);
- } while (h_block > 0);
- }
- if (thr_eff < thr_eff_treshold) // we didn't find suitable h_block
- {
- h_block = 1;
- int nb_h = oh;
- do {
- int nb_w = div_up(ow, w_block);
- size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
- float disb = (float)ow / rnd_up(ow, w_block);
- thr_eff = (float)work_amount / rnd_up(work_amount, max_threads);
- thr_eff = (thr_eff + disb) / 2.f;
- if (thr_eff > thr_eff_treshold)
- break;
- w_block = rnd_dn(w_block - simd_w, simd_w);
- } while (w_block > 0);
- }
- h_block = nstl::max(1, h_block);
- w_block = nstl::max(1, w_block);
- const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
- const float inner_thr_eff
- = (float)inner_work / rnd_up(inner_work, max_threads);
- if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
- jcp.oh_block = h_block;
- jcp.ow_block = w_block;
- jcp.outer_threading = true;
- }
- // updating jcp.im2col_sz
- if (jcp.oh_block != 1)
- jcp.ow_block = ow;
- jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
- }
- // For threading selection in bwd_d we do:
- // 1. Rough estimation of efficiency for inner and outer threading.
- // 2. Gemm size estimation in assumption that it does not work
- // so effectively for small sizes.
- // 64K - this is heuristic gemm size per thread threshold.
- const int gemm_thrld = 64 * 1024;
-
- if (is_int8_conv) {
- if (is_fwd) {
- if (!jcp.outer_threading) {
- bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
- const size_t outer_work = jcp.ngroups * jcp.mb;
- const float outer_thr_eff
- = (float)outer_work / rnd_up(outer_work, max_threads);
- const size_t inner_work
- = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
- const float inner_thr_eff
- = (float)inner_work / rnd_up(inner_work, max_threads);
- jcp.outer_threading = (is_depthwise
- || (jcp.is / max_threads < 64 && jcp.mb != 1))
- && (outer_thr_eff / inner_thr_eff >= 1.f
- || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
- }
- jcp.nthr = jcp.outer_threading ? max_threads : 1;
- scratchpad.book(key_conv_gemm_col,
- sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
- scratchpad.book(key_conv_int_dat_in_acc_dt,
- sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc);
- scratchpad.book(key_conv_gemm_imtr,
- sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
- } else if (is_bwd_d) {
- bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
- const size_t outer_work = jcp.ngroups * jcp.mb;
- const float outer_thr_eff
- = (float)outer_work / rnd_up(outer_work, max_threads);
- const size_t inner_work
- = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
- const float inner_thr_eff
- = (float)inner_work / rnd_up(inner_work, max_threads);
- jcp.outer_threading = (is_depthwise
- || (jcp.is / max_threads < 64 && jcp.mb != 1))
- && (outer_thr_eff / inner_thr_eff >= 1.f
- || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
-
- jcp.nthr = jcp.outer_threading ? max_threads : 1;
- scratchpad.book(key_conv_gemm_col,
- sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
- scratchpad.book(key_conv_int_dat_in_acc_dt,
- sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
- } else if (is_bwd_w) {
- assert(!"unimplemented prop_kind");
- return status::unimplemented;
- }
- } else {
- if (is_fwd) {
- if (!jcp.outer_threading) {
- const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
- const float outer_thr_eff = (float)outer_work_amount
- / rnd_up(outer_work_amount, max_threads);
- const size_t inner_work_amount
- = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
- const float inner_thr_eff = (float)inner_work_amount
- / rnd_up(inner_work_amount, max_threads);
- jcp.outer_threading = jcp.os / max_threads < 512
- && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
- && (outer_thr_eff / inner_thr_eff >= 1.f
- || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
- }
- } else if (is_bwd_d) {
- const size_t outer_work_amount = jcp.ngroups * jcp.mb;
- const float outer_thr_eff = (float)outer_work_amount
- / rnd_up(outer_work_amount, max_threads);
- const size_t inner_work
- = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
- const float inner_thr_eff = (float)inner_work
- / rnd_up(inner_work, max_threads);
- jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
- && (jcp.mb != 1 || jcp.ngroups > 2)
- && (outer_thr_eff / inner_thr_eff >= 1.f
- || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
- } else if (is_bwd_w)
- jcp.outer_threading = jcp.os / max_threads < 256
- && (jcp.mb != 1 || jcp.ngroups > 2);
-
- jcp.nthr = jcp.outer_threading ? max_threads : 1;
- scratchpad.book(key_conv_gemm_col,
- sizeof(float) * jcp.nthr * jcp.im2col_sz);
-
- if (is_bwd_w) {
- jcp.need_wei_reduction = mkldnn_thr_syncable()
- ? jcp.mb != 1 && jcp.nthr != 1 : false;
- scratchpad.book(key_conv_wei_reduction,
- sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
- }
- }
-
- return status::success;
-}
-
-void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
- int &nthr_g, int &ithr_mb, int &nthr_mb) {
- nthr_g = nstl::min(ngroups, nthr);
- nthr_mb = nstl::min(mb, nthr / nthr_g);
- if (ithr / nthr_mb >= ngroups) {
- ithr_g = ithr_mb = -1;
- } else {
- ithr_g = ithr / nthr_mb;
- ithr_mb = ithr % nthr_mb;
- }
-}
-
-void bwd_weights_reduction_par(int ithr, int nthr,
- const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
- float *weights) {
- const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
-
- size_t weights_start{0}, weights_end{0};
- balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
-
- for (int i = 0; i < nthr; ++i) {
- const float *ws_i = weights_reduce_ws + i * weights_g_size;
- for (size_t s = weights_start; s < weights_end; ++s)
- weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
- }
-}
-
-};
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp
deleted file mode 100644
index e006789344..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp
+++ /dev/null
@@ -1,66 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP
-#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_engine.hpp"
-#include "jit_primitive_conf.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace jit_gemm_convolution_utils {
-
-void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
- int od);
-void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
- float *__restrict col, int hs, int hb, int ws, int wb);
-template <typename T>
-void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
- T* __restrict imtr, uint8_t *__restrict col,
- int hs, int hb, int ws, int wb);
-
-void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
- int32_t *__restrict im);
-void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
- int od);
-void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im);
-
-status_t init_conf(jit_gemm_conv_conf_t &jcp,
- memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, int max_threads);
-
-void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb,
- int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb);
-void bwd_weights_reduction_par(int ithr, int nthr,
- const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
- float *weights);
-
-}
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp
deleted file mode 100644
index 2872122f0d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp
+++ /dev/null
@@ -1,156 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "gemm_inner_product.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::data_type;
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::primitive_kind;
-
-template <impl::data_type_t data_type>
-void gemm_inner_product_fwd_t<data_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC_total_padded();
-
- bool wei_tr = !memory_desc_matches_one_of_tag(
- *pd()->weights_md(), hwio, dhwio, io);
-
- const auto &post_ops = pd()->attr()->post_ops_;
- const bool do_relu = post_ops.len_ == 1;
-
- float alpha = 1.0, beta = 0.0;
- extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
- wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias);
-
- if (do_relu) {
- float nslope = post_ops.entry_[0].eltwise.alpha;
- parallel_nd(MB, OC, [&](int mb, int oc) {
- size_t dst_off = mb * OC + oc;
- if (dst[dst_off] < 0)
- dst[dst_off] *= nslope;
- });
- }
-}
-
-template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC_total_padded();
-
- bool wei_tr = memory_desc_matches_one_of_tag(
- *pd()->weights_md(), hwio, dhwio, io);
-
- float alpha = 1.0, beta = 0.0;
- extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
- wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
-}
-
-template <impl::data_type_t data_type>
-void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
-
- diff_dst += diff_dst_d.offset0();
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC_total_padded();
-
- bool wei_tr = memory_desc_matches_one_of_tag(
- *pd()->diff_weights_md(), hwio, dhwio, io);
-
- float alpha = 1.0, beta = 0.0;
- if (wei_tr)
- extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
- &beta, diff_weights, &OC);
- else
- extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
- &beta, diff_weights, &IC);
-
- if (diff_bias) {
- diff_bias += diff_bias_d.offset0();
- constexpr int blksize = 8;
- const int OC_blocks = OC / blksize;
- const int rem_OC = OC % blksize;
- parallel(0, [&](const int ithr, const int nthr) {
- int oc_st{0}, oc_e{0};
- balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
- oc_st = oc_st * blksize;
- oc_e = oc_e * blksize;
-
- PRAGMA_OMP_SIMD()
- for (int oc = oc_st; oc < oc_e; ++oc) {
- diff_bias[oc] = diff_dst[oc];
- }
-
- for (int mb = 1; mb < MB; ++mb) {
- PRAGMA_OMP_SIMD()
- for (int oc = oc_st; oc < oc_e; ++oc) {
- diff_bias[oc] += diff_dst[mb * OC + oc];
- }
- }
-
- if (rem_OC != 0 && ithr == nthr-1) {
- for (int oc = OC_blocks * blksize; oc < OC; oc++)
- diff_bias[oc] = diff_dst[oc];
- for (int mb = 1; mb < MB; ++mb) {
- for (int oc = OC_blocks * blksize; oc < OC; oc++) {
- diff_bias[oc] += diff_dst[mb * OC + oc];
- }
- }
- }
- });
- }
-}
-
-template struct gemm_inner_product_fwd_t<data_type::f32>;
-template struct gemm_inner_product_bwd_data_t<data_type::f32>;
-template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp
deleted file mode 100644
index acf0a49b9a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp
+++ /dev/null
@@ -1,157 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_GEMM_INNER_PRODUCT_HPP
-#define CPU_GEMM_INNER_PRODUCT_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "gemm/gemm.hpp"
-
-#include "cpu_inner_product_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-struct gemm_inner_product_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_fwd_pd_t {
- using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t);
-
- status_t init() {
- using namespace utils;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && !has_zero_dim_memory()
- && everyone_is(data_type,
- src_md()->data_type,
- weights_md()->data_type,
- dst_md()->data_type,
- with_bias() ? weights_md(1)->data_type : data_type)
- && attr()->output_scales_.has_default_values()
- && attr()->post_ops_.len_ <= 1
- && IMPLICATION(attr()->post_ops_.len_ == 1,
- attr()->post_ops_.entry_[0].is_relu(true, false))
- && dense_gemm_consitency_check(src_md(), weights_md(),
- dst_md());
- return ok ? status::success : status::unimplemented;
- }
- };
-
- gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct gemm_inner_product_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_bwd_data_pd_t {
- using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && desc()->prop_kind == prop_kind::backward_data
- && !has_zero_dim_memory()
- && utils::everyone_is(data_type,
- diff_src_md()->data_type,
- weights_md()->data_type,
- diff_dst_md()->data_type)
- && attr()->has_default_values()
- && dense_gemm_consitency_check(diff_src_md(), weights_md(),
- diff_dst_md());
- return ok ? status::success : status::unimplemented;
- }
- };
-
- gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
- using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t;
-
- DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && desc()->prop_kind == prop_kind::backward_weights
- && !has_zero_dim_memory()
- && utils::everyone_is(data_type,
- src_md()->data_type,
- diff_weights_md()->data_type,
- diff_dst_md()->data_type,
- with_bias() ? diff_weights_md(1)->data_type : data_type)
- && attr()->has_default_values()
- && dense_gemm_consitency_check(src_md(), diff_weights_md(),
- diff_dst_md());
-
- return ok ? status::success : status::unimplemented;
- }
- };
-
- gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp
deleted file mode 100644
index fed7e4d693..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp
+++ /dev/null
@@ -1,740 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "utils.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "math_utils.hpp"
-
-#include "simple_q10n.hpp"
-
-#include "gemm/gemm.hpp"
-#include "gemm_x8s8s32x_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace mkldnn::impl::memory_tracking::names;
-
-template <data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
-execute_forward(const exec_ctx_t &ctx) const {
- auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- auto scratchpad = this->scratchpad(ctx);
-
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- assert(IMPLICATION(
- jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
- assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base,
- scratchpad);
- });
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t(
- const pd_t *pd)
- : ker_(nullptr)
- , jcp_(pd->jcp_)
- , OC_(pd->jcp_.oc)
- , OS_(pd->jcp_.os)
- , bias_data_type_(data_type::undef)
- , bias_data_type_size_(0)
- , scale_idx_mult_(0)
- , do_bias_(false)
- , do_relu_(false)
- , do_sum_(false)
-{
- using namespace types;
-
- const auto dst_md = memory_desc_wrapper(pd->dst_md());
- dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1);
-
- scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
-
- auto &post_ops = pd->attr()->post_ops_;
-
- int entry_idx = -1;
- for (int idx = 0; idx < post_ops.len_; ++idx) {
- const auto &e = post_ops.entry_[idx];
- if (e.is_relu(true, false)) {
- entry_idx = idx;
- break;
- }
- }
- do_relu_ = entry_idx >= 0;
-
- do_signed_scaling_ = jcp_.signed_input;
-
- do_sum_ = post_ops.contain(primitive_kind::sum, 0);
- do_bias_ = pd->with_bias();
- bias_data_type_ = pd->desc()->bias_desc.data_type;
- if (do_bias_) {
- assert(bias_data_type_ != data_type::undef);
- bias_data_type_size_ = data_type_size(bias_data_type_);
- }
- const size_t vlen_start
- = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
- for (size_t i = vlen_start; i > 0; i--) {
- if (OC_ % i == 0) {
- vlen_ = i;
- break;
- }
- }
-
- if (!mayiuse(avx512_core))
- // use fallback code for older CPUs
- return;
- else
- generate();
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate()
-{
- using namespace Xbyak;
- using namespace utils;
-
- // TODO: clean-up
- Reg64 reg_param = abi_param1;
- Reg64 reg_dst = rdx;
- Reg64 reg_acc = rax;
- Reg64 reg_bias = rbx;
- Reg64 reg_scales = rsi;
-
- Reg64 reg_len = r8;
- Reg64 reg_tmp = rcx; // intentional for shifting purposes
- Reg64 reg_oc_offset = r9;
- Reg64 reg_rem_mask_short = r10;
- Reg64 reg_rem_mask_vlen = r11;
- Opmask kreg_rem_mask_short = k1;
- Opmask kreg_rem_mask_vlen = k3;
- Opmask kreg_relu_cmp = k2;
-
- const size_t vlen = vlen_;
-
- Zmm vreg_zero = Zmm(0);
- Zmm vreg_scale = Zmm(1);
- Zmm vreg_nslope = Zmm(2);
- Zmm vreg_sum_scale = Zmm(3);
- Zmm vreg_signed_scale = Zmm(4);
-
- size_t def_unroll = 4;
- size_t max_unroll = 12;
- size_t zmm_step = 2;
- if (do_sum_) {
- max_unroll = 8;
- zmm_step = 3;
- }
-
- auto vreg_dst = [&](int idx) {
- return Zmm(5 + idx * zmm_step + 0);
- };
- auto vreg_bias = [&](int idx) {
- return Zmm(5 + idx * zmm_step + 1);
- };
- auto vreg_prev_dst = [&](int idx) {
- return Zmm(5 + idx * zmm_step + 2);
- };
-
- preamble();
-
-#define PARAM_OFF(x) offsetof(ker_args, x)
- mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
- mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
- mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
- mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
- mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
- mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
- vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
- vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
- vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]);
- if (scale_idx_mult_ == 0)
- vbroadcastss(vreg_scale, dword[reg_scales]);
-
-#undef PARAM_OFF
-
- mov(reg_rem_mask_vlen, 1);
- shl(reg_rem_mask_vlen, vlen);
- sub(reg_rem_mask_vlen, 1);
- kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen);
-
- if (do_relu_ || dst_type == data_type::u8)
- vxorps(vreg_zero, vreg_zero, vreg_zero);
-
- // Load accumulated value, convert to float, apply sum (if any),
- // bias (if any), scaling, and relu (if any);
- // then convert to destination type and store
- auto compute = [&](size_t offset, int idx, bool apply_mask) {
- auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
-
- if (scale_idx_mult_ > 0) {
- assert(scale_idx_mult_ == 1);
- auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
- auto vreg_scale_ = vreg_scale;
- if (apply_mask)
- vreg_scale_ = vreg_scale_ | kreg_rem_mask_short;
- else
- vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen;
- vmovups(vreg_scale_, scale_addr);
- }
-
- auto vreg_dst_ = vreg_dst(idx);
- if (apply_mask)
- vreg_dst_ = vreg_dst_ | kreg_rem_mask_short;
- else
- vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen;
- vcvtdq2ps(vreg_dst_, acc_addr);
-
- if (do_signed_scaling_)
- vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale);
-
- if (do_bias_) {
- auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
- auto vreg_bias_ = vreg_bias(idx);
- if (apply_mask)
- vreg_bias_ = vreg_bias_ | kreg_rem_mask_short;
- else
- vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen;
-
- switch (bias_data_type_) {
- case data_type::s8:
- vpmovsxbd(vreg_bias_, bias_addr);
- break;
- case data_type::u8:
- vpmovzxbd(vreg_bias_, bias_addr);
- break;
- case data_type::s32:
- case data_type::f32:
- vmovups(vreg_bias_, bias_addr);
- break;
- default: assert(!"unimplemented");
- }
- if (bias_data_type_ != data_type::f32)
- vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
- vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
- }
-
- vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
-
- auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
-
- if (do_sum_)
- {
- auto vreg_prev_dst_ = vreg_prev_dst(idx);
- if (apply_mask)
- vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short;
- else
- vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen;
-
- switch (dst_type) {
- case data_type::f32:
- case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break;
- case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break;
- case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break;
- default: assert(!"unsupported data type");
- }
- if (dst_type != data_type::f32)
- vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
-
- vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
- }
-
- if (do_relu_) {
- vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
- vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
- }
-
- if (dst_type != data_type::f32) {
- vcvtps2dq(vreg_dst(idx), vreg_dst(idx));
- }
-
- if (dst_type == data_type::u8)
- vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero);
-
- switch (dst_type) {
- case data_type::s8:
- vpmovsdb(dst_addr, vreg_dst_);
- break;
- case data_type::u8:
- vpmovusdb(dst_addr, vreg_dst_);
- break;
- case data_type::f32:
- case data_type::s32:
- vmovups(dst_addr, vreg_dst_);
- break;
- default: assert(!"unimplemented");
- }
- };
-
- // Advance all pointers by an immediate
- auto advance_ptrs_imm = [&](size_t offset) {
- add(reg_dst, offset * sizeof(dst_data_t));
- add(reg_acc, offset * sizeof(acc_data_t));
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- add(reg_scales, offset * sizeof(float));
- }
- if (do_bias_)
- add(reg_bias, offset * bias_data_type_size_);
- };
-
- // Advance all pointers by a value stored in a register
- auto advance_ptrs_reg = [&](Reg64 offset) {
- lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
- lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
- }
- if (do_bias_)
- lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
- };
-
- // Rewind pointers that point to data that is indexed by output channel
- // (bias or per-oc scaling factors)
- auto rewind_ptrs = [&]() {
- if (do_bias_)
- sub(reg_bias, OC_ * bias_data_type_size_);
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- sub(reg_scales, OC_ * sizeof(float));
- }
- add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t));
- };
-
- // <--------- OC --------------->
- //
- // ^ ................+..............+-------------+.......................
- // | . : not accessed |Prologue loop| .
- // | . +--------------+-------------+ .
- // . | | .
- // O . | Main loop (unrolled) | .
- // S . | | .
- // . +--------------+-------------+ .
- // | . | Epilogue loop|not accessed : .
- // v ................+--------------+.............+.......................
-
- Label prologue_end;
- cmp(reg_oc_offset, 0);
- je(prologue_end, T_NEAR);
-
- // Prologue loop
- {
- mov(reg_tmp, OC_);
- sub(reg_tmp, reg_oc_offset);
- cmp(reg_tmp, reg_len);
- cmovg(reg_tmp, reg_len);
- sub(reg_len, reg_tmp);
-
- Label prologue_loop, prologue_loop_tail, prologue_loop_end;
- cmp(reg_tmp, vlen);
- jle(prologue_loop_tail, T_NEAR);
- L(prologue_loop); {
- compute(0, 0, false);
- advance_ptrs_imm(vlen);
- sub(reg_tmp, vlen);
- cmp(reg_tmp, vlen);
- jge(prologue_loop, T_NEAR);
- }
-
- L(prologue_loop_tail);
- mov(reg_rem_mask_short, 1);
- // cl == reg_tmp because reg_tmp <= vlen here
- shl(reg_rem_mask_short, cl);
- sub(reg_rem_mask_short, 1);
- jz(prologue_loop_end, T_NEAR);
-
- kmovq(kreg_rem_mask_short, reg_rem_mask_short);
- compute(0, 0, true);
- advance_ptrs_reg(reg_tmp);
-
- L(prologue_loop_end);
- rewind_ptrs();
- }
- L(prologue_end);
-
- // Main loop
- Label main_loop_end;
- {
- cmp(reg_len, OC_);
- jle(main_loop_end, T_NEAR);
-
- Label main_loop;
- L(main_loop); {
- size_t OC_loop, OC_tail;
- if (OC_ < max_unroll * vlen) {
- // Fully unroll small loops
- OC_loop = 0;
- OC_tail = OC_;
- }
- else {
- OC_loop = vlen * def_unroll;
- OC_tail = OC_ % OC_loop;
- }
-
- assert(!!OC_loop || !!OC_tail);
-
- if (OC_tail % vlen) {
- int vlen_tail = OC_tail % vlen;
- unsigned tail_mask = (1 << vlen_tail) - 1;
- mov(reg_tmp, tail_mask);
- kmovq(kreg_rem_mask_short, reg_tmp);
- }
-
- if (OC_loop) {
- mov(reg_tmp, rnd_dn(OC_, OC_loop));
- Label oc_loop;
- L(oc_loop); {
- for (size_t offset = 0; offset < OC_loop; offset += vlen)
- compute(offset, offset / vlen, false);
- advance_ptrs_imm(OC_loop);
- sub(reg_tmp, OC_loop);
- jnz(oc_loop);
- }
- }
-
- if (OC_tail) {
- for (size_t offset = 0; offset < OC_tail; offset += vlen) {
- bool use_mask = (offset + vlen) > OC_tail;
- compute(offset, offset / vlen, use_mask);
- }
- advance_ptrs_imm(OC_tail);
- }
-
- rewind_ptrs();
- sub(reg_len, OC_);
- cmp(reg_len, OC_);
- jge(main_loop, T_NEAR);
- }
- }
- L(main_loop_end);
-
- // Epilogue loop
- Label epilogue_end;
- {
- cmp(reg_len, 0);
- je(epilogue_end, T_NEAR);
-
- Label epilogue_loop, epilogue_loop_tail;
- cmp(reg_len, vlen);
- jle(epilogue_loop_tail, T_NEAR);
- L(epilogue_loop); {
- compute(0, 0, false);
- sub(reg_len, vlen);
- advance_ptrs_imm(vlen);
- cmp(reg_len, vlen);
- jge(epilogue_loop, T_NEAR);
- }
-
- L(epilogue_loop_tail);
- mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
- mov(reg_rem_mask_short, 1);
- shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen
- sub(reg_rem_mask_short, 1);
- jz(epilogue_end, T_NEAR);
- kmovq(kreg_rem_mask_short, reg_rem_mask_short);
- compute(0, 0, true);
- }
-
- L(epilogue_end);
-
- postamble();
-
- ker_ = getCode<decltype(ker_)>();
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
- (dst_data_t *dst, const acc_data_t *acc, const char *bias,
- const float *scales, float nslope, float sum_scale, float signed_scale,
- int g, size_t start, size_t end)
-{
- using math::get_bias;
-
- if (end <= start)
- return;
-
- if (ker_) {
- // JIT
- ker_args args;
- size_t oc_offset = start % OC_;
- size_t os_offset = start / OC_;
- args.acc = acc + start;
- args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
- args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
- args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
- args.nslope = nslope;
- args.sum_scale = sum_scale;
- args.signed_scale = signed_scale;
- args.len = end - start;
- args.oc_offset = oc_offset;
- ker_(&args);
- }
- else {
- // Fallback
- const size_t first_oc = start % OC_;
- const size_t last_oc = (end - 1) % OC_;
- const size_t first_os = start / OC_;
- const size_t last_os = (end - 1) / OC_;
- for (size_t os = first_os; os <= last_os; os++) {
- const size_t start_oc = (os == first_os) ? first_oc : 0;
- const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
- for (size_t oc = start_oc; oc <= end_oc; oc++) {
- const size_t acc_off = os * jcp_.oc + oc;
- const size_t dst_off = os * dst_os_stride_ + oc;
-
- float d = (float)(acc[acc_off]);
- if (jcp_.signed_input)
- d *= signed_scale;
-
- if (do_bias_)
- d += get_bias(bias, g * jcp_.oc + oc,
- bias_data_type_);
-
- d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
- if (do_sum_)
- d += sum_scale * dst[dst_off];
- if (do_relu_ && d < 0)
- d *= nslope;
- dst[dst_off] = qz_a1b0<float, dst_data_t>()(d);
- }
- }
- }
-};
-
-template <data_type_t src_type, data_type_t dst_type>
-void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
-execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base,
- const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
- const memory_tracking::grantor_t &scratchpad) const {
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- const auto src_md = memory_desc_wrapper(pd()->src_md());
- const size_t src_mb_stride = src_md.blk_off(1);
- const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
-
- const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
- const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
-
- const auto dst_md = memory_desc_wrapper(pd()->dst_md());
- const size_t dst_mb_stride = dst_md.blk_off(1);
- const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
-
- const float *scales = pd()->attr()->output_scales_.scales_;
-
- const auto &post_ops = pd()->attr()->post_ops_;
- const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
- const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
-
- float nslope = 0;
- for (int idx = 0; idx < post_ops.len_; ++idx) {
- const auto &e = post_ops.entry_[idx];
- if (e.is_relu(true, false)) {
- nslope = e.eltwise.alpha;
- break;
- }
- }
-
- auto col = scratchpad.get<uint8_t>(key_conv_gemm_col)
- + (ptrdiff_t)ithr * jcp.im2col_sz;
- src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
- + (ptrdiff_t)ithr * jcp.is * jcp.ic;
- auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
- + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;
-
- const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
- const int32_t *_wei_comp = (const int32_t *)(wei_base + offset);
-
- int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 };
- size_t start = 0, end = 0;
-
- const int nb_oh = div_up(jcp.oh, jcp.oh_block);
- const int nb_ow = div_up(jcp.ow, jcp.ow_block);
- const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow;
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb,
- nb_oh, owb, nb_ow);
-
- for (size_t iwork = start; iwork < end; ++iwork) {
- int oh = ohb * jcp.oh_block;
- int ow = owb * jcp.ow_block;
- const src_data_t *__restrict src = src_base + n * src_mb_stride
- + g * src_g_stride;
- const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
- dst_data_t *__restrict dst =
- dst_base + n * dst_mb_stride + g * dst_g_stride;
- const int32_t *wei_comp = _wei_comp + g * jcp.oc;
- const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
- const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
-
- if (jcp.im2col_sz)
- jit_gemm_convolution_utils::im2col_u8<src_data_t>(
- jcp, src, imtr, col, oh, h_step, ow, w_step);
-
- const int M = jcp.oc;
- const int K = jcp.ks * jcp.ic;
- const int N = h_step * w_step;
- const int LDA = M * jcp.ngroups;
- const int LDB = jcp.im2col_sz ? N : K;
- const char *BT = jcp.im2col_sz ? "T" : "N";
- const int8_t off_a = 0, off_b = 0;
- const int32_t off_c = 0;
- const float onef = 1.0, zerof = 0.0;
- gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F",
- &M, &N, &K, &onef, wei, &LDA, &off_a,
- jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b,
- &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);
-
- auto wei_adj_scale =
- (wei_md.extra().flags | memory_extra_flags::scale_adjust)
- ? wei_md.extra().scale_adjust : 1.f;
-
- parallel(0, [&](int ithr, int nthr) {
- size_t start, end;
- balance211((size_t)N * jcp.oc, nthr, ithr, start, end);
- (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_,
- acc, bia_base, scales, nslope, sum_scale,
- 1.f / wei_adj_scale, g, start, end);
- });
-
- nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh,
- owb, nb_ow);
- }
-}
-
-template <data_type_t dst_type>
-void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
-execute_backward_data(const exec_ctx_t &ctx) const {
- auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- auto scratchpad = this->scratchpad(ctx);
-
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
- bia_base, diff_src_base, scratchpad);
- });
-}
-
-template <data_type_t dst_type>
-void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
-execute_backward_data_thr(const int ithr, const int nthr,
- const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
- const char *bia_base, diff_src_data_t *diff_src_base,
- const memory_tracking::grantor_t &scratchpad) const
-{
- const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
-
- const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
- const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
- const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
-
- const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
- const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
-
- const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
- const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
- const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
- const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
-
- /* scale_idx_mult = 1 for per_oc scales and 0, otherwise */
- const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
- const float *scales = pd()->attr()->output_scales_.scales_;
- const size_t work_amount = jcp.ngroups * jcp.mb;
-
- auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
- + (ptrdiff_t)ithr * jcp.im2col_sz;
- auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
- + (ptrdiff_t)ithr * jcp.is * jcp.ic;
-
- int n{0}, g{0};
- size_t start = 0, end = 0;
-
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
-
- for (size_t iwork = start; iwork < end; ++iwork) {
- const diff_dst_data_t *diff_dst = diff_dst_base
- + n * diff_dst_mb_stride + g * diff_dst_g_stride;
- const wei_data_t *wei = wei_base + g * wei_g_stride;
- diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
- + g * diff_src_g_stride;
-
- const int M = jcp.ks * jcp.ic;
- const int N = jcp.os;
- const int K = jcp.oc;
- const int8_t off_a = 0, off_b = 0;
- const int32_t off_c = 0;
- const float onef = 1.0, zerof = 0.0;
- const int LD = K * jcp.ngroups;
-
- gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef,
- wei, &LD, &off_a, diff_dst, &LD, &off_b,
- &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);
-
- if (jcp.im2col_sz)
- jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
-
- parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
- float d = (float)acc[is * jcp.ic + ic];
- if (jcp.with_bias)
- d += get_bias(bia_base, g * jcp.ic + ic,
- pd()->desc()->bias_desc.data_type);
- d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
- const size_t diff_src_off = is * diff_src_os_stride + ic;
- diff_src[diff_src_off] =
- qz_a1b0<float, diff_src_data_t>()(d);
- });
- nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
- }
-}
-
-using namespace data_type;
-
-template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;
-
-template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
-template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;
-
-template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
-template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
-template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
-template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
deleted file mode 100644
index 9e77b890d5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
+++ /dev/null
@@ -1,266 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP
-#define GEMM_X8S8S32X_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_primitive_conf.hpp"
-#include "jit_generator.hpp"
-#include "gemm_convolution_utils.hpp"
-
-#include "gemm/gemm.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t src_type, data_type_t dst_type>
-struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
- _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
-
- status_t init() {
- using namespace data_type;
-
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, s8, data_type::undef, dst_type,
- s32)
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, f32, s32, s8, u8))
- && !has_zero_dim_memory()
- && set_default_formats_common(
- dat_tag(), format_tag::any, dat_tag())
- && post_ops_ok()
- && memory_desc_matches_tag(*src_md(), dat_tag())
- && memory_desc_matches_tag(*dst_md(), dat_tag())
- && set_or_check_wei_format();
- if (!ok) return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
- *desc(), src_md(), weights_md(0), dst_md(),
- mkldnn_get_max_threads());
- }
-
- jit_gemm_conv_conf_t jcp_;
-
- protected:
- format_tag_t dat_tag() const { return format_tag::nhwc; }
-
- bool set_or_check_wei_format() {
- using namespace format_tag;
-
- const bool is_src_s8 = src_md_.data_type == data_type::s8;
-
- memory_desc_t want_wei_md = weights_md_;
- memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio);
-
- if (is_src_s8) {
- want_wei_md.extra.flags = 0
- | memory_extra_flags::compensation_conv_s8s8
- | memory_extra_flags::scale_adjust;
- want_wei_md.extra.compensation_mask = (1 << 0)
- + (with_groups() ? (1 << 1) : 0);
- want_wei_md.extra.scale_adjust =
- mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
- }
-
- if (weights_md_.format_kind == format_kind::any) {
- weights_md_ = want_wei_md;
- return true;
- }
-
- return weights_md_ == want_wei_md;
- }
-
- bool post_ops_ok() const {
- using namespace mkldnn::impl::primitive_kind;
- auto const &po = attr()->post_ops_;
- auto is_relu = [&](int idx) {
- return po.entry_[idx].is_relu(true, false); };
-
- switch (po.len_) {
- case 0: return true;
- case 1: return is_relu(0) || po.contain(sum, 0);
- case 2: return po.contain(sum, 0) && is_relu(1);
- default: return false;
- }
- return false;
- }
- };
-
- _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd, true), pp_ker_(nullptr)
- { pp_ker_ = new pp_ker_t(pd()); }
- ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
- typedef typename prec_traits<data_type::s32>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- // XXX: this is throwaway code that will become unnecessary when we have a
- // sufficiently advanced igemm jit generator that supports quantization,
- // relu, and whatnot
- class pp_ker_t : jit_generator {
- public:
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- _gemm_x8s8s32x_convolution_fwd_t::pp_kernel);
- pp_ker_t(const pd_t *pd);
-
- void operator()(dst_data_t *dst, const acc_data_t *acc,
- const char *bias, const float *scales,
- float nslope, float sum_scale, float signed_scale,
- int g, size_t start, size_t end);
-
- size_t dst_os_stride_;
-
- private:
- void generate();
-
- struct ker_args {
- dst_data_t *dst;
- const acc_data_t *acc;
- const char *bias;
- const float *scales;
- float nslope;
- float sum_scale;
- float signed_scale;
- size_t len;
- size_t oc_offset;
- };
- void(*ker_)(const ker_args *args);
-
- const jit_gemm_conv_conf_t &jcp_;
- size_t OC_;
- size_t OS_;
- data_type_t bias_data_type_;
- size_t bias_data_type_size_;
- size_t scale_idx_mult_;
- bool do_bias_;
- bool do_relu_;
- bool do_sum_;
- bool do_signed_scaling_;
- size_t vlen_;
- };
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- void execute_forward(const exec_ctx_t &ctx) const;
- void execute_forward_thr(const int ithr, const int nthr,
- const src_data_t *src_base, const wei_data_t *wei_base,
- const char *bia_base, dst_data_t *dst_base,
- const memory_tracking::grantor_t &scratchpad) const;
-
- int nthr_;
- pp_ker_t *pp_ker_;
-
-};
-
-template <data_type_t dst_type>
-struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t{
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc, const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
- _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>);
-
- status_t init() {
- using namespace data_type;
-
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(dst_type, s8, data_type::undef, u8, s32)
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, f32, s32, s8, u8))
- && !has_zero_dim_memory()
- && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
- && attr()->post_ops_.has_default_values()
- && memory_desc_matches_tag(*diff_src_md(), dat_tag())
- && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
- && memory_desc_matches_tag(*weights_md(), wei_tag());
- if (!ok) return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
- *desc(), diff_src_md(), weights_md(), diff_dst_md(),
- mkldnn_get_max_threads());
- }
-
- virtual bool support_bias() const override { return true; }
-
- jit_gemm_conv_conf_t jcp_;
-
- protected:
- format_tag_t dat_tag() const { return format_tag::nhwc; }
-
- format_tag_t wei_tag() const {
- return with_groups() ? format_tag::hwigo : format_tag::hwio;
- }
- };
-
- _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd)
- : cpu_primitive_t(apd, true) {}
-
- typedef typename prec_traits<data_type::u8>::type diff_dst_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type diff_src_data_t;
- typedef typename prec_traits<data_type::s32>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- void execute_backward_data_thr(const int ithr, const int nthr,
- const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
- const char *bia_base, diff_src_data_t *diff_src_base,
- const memory_tracking::grantor_t &scratchpad) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp
deleted file mode 100644
index 1e435a233a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp
+++ /dev/null
@@ -1,453 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "simple_q10n.hpp"
-
-#include "gemm/gemm.hpp"
-#include "gemm_x8s8s32x_inner_product.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace math;
-using namespace format_tag;
-using namespace memory_tracking::names;
-
-template<data_type_t src_type, data_type_t dst_type>
-gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::pp_kernel_t(
- const pd_t *pd, bool dst_is_acc)
- : ker_(nullptr), OC_(pd->OC())
- , bias_data_type_(data_type::undef), bias_data_type_size_(0)
- , scale_idx_mult_(0), do_bias_(false), do_relu_(false)
-{
- using namespace types;
-
- scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
-
- auto &post_ops = pd->attr()->post_ops_;
- do_relu_ = post_ops.len_ == 1;
- do_bias_ = pd->with_bias();
- bias_data_type_ = pd->desc()->bias_desc.data_type;
- if (do_bias_) {
- assert(bias_data_type_ != data_type::undef);
- bias_data_type_size_ = data_type_size(bias_data_type_);
- }
-
- if (!mayiuse(avx512_core))
- // use fallback code for older CPUs since they do not have optimized
- // x8s8s32 GEMM anyway. The configuration variables above are used by
- // the fallback code.
- return;
- else
- generate();
-}
-
-template<data_type_t src_type, data_type_t dst_type>
-void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::generate()
-{
- using namespace Xbyak;
- using namespace utils;
-
- // TODO: clean-up
- Reg64 reg_param = abi_param1;
- Reg64 reg_dst = rdx;
- Reg64 reg_acc = rax;
- Reg64 reg_bias = rbx;
- Reg64 reg_scales = rsi;
-
- Reg64 reg_len = r8;
- Reg64 reg_tmp = rcx; // intentional for shifting purposes
- Reg64 reg_oc_offset = r9;
- Reg64 reg_rem_mask = r10;
- Opmask kreg_rem_mask = k1;
- Opmask kreg_relu_cmp = k2;
-
- const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
-
- Zmm vreg_zero = Zmm(0);
- Zmm vreg_scale = Zmm(1);
- Zmm vreg_nslope = Zmm(2);
-
- auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
- auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
-
- preamble();
-
-#define PARAM_OFF(x) offsetof(ker_args, x)
- mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
- mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
- mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
- mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
- mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
- mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
- vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
- if (scale_idx_mult_ == 0)
- vbroadcastss(vreg_scale, dword[reg_scales]);
-#undef PARAM_OFF
-
- if (do_relu_ || dst_type == data_type::u8)
- vxorps(vreg_zero, vreg_zero, vreg_zero);
-
- // Load accumulated value, convert to float, apply bias (if any), scaling,
- // and relu (if any); then convert to destination type and store
- auto compute = [&](size_t offset, int idx, bool apply_mask) {
- auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
-
- if (scale_idx_mult_ > 0) {
- assert(scale_idx_mult_ == 1);
- auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
- auto vreg_scale_ = vreg_scale;
- if (apply_mask)
- vreg_scale_ = vreg_scale_ | kreg_rem_mask;
- vmovups(vreg_scale, scale_addr);
- }
-
- auto vreg_dst_ = vreg_dst(idx);
- if (apply_mask)
- vreg_dst_ = vreg_dst_ | kreg_rem_mask;
- vcvtdq2ps(vreg_dst_, acc_addr);
-
- if (do_bias_) {
- auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
- auto vreg_bias_ = vreg_bias(idx);
- if (apply_mask)
- vreg_bias_ = vreg_bias_ | kreg_rem_mask;
-
- switch (bias_data_type_) {
- case data_type::s8:
- vpmovsxbd(vreg_bias_, bias_addr);
- break;
- case data_type::u8:
- vpmovzxbd(vreg_bias_, bias_addr);
- break;
- case data_type::s32:
- case data_type::f32:
- vmovups(vreg_bias_, bias_addr);
- break;
- default: assert(!"unimplemented");
- }
- if (bias_data_type_ != data_type::f32)
- vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
- vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
- }
-
- vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
- if (do_relu_) {
- vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
- vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
- }
-
- if (dst_type == data_type::u8)
- vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);
-
- if (dst_type != data_type::f32) {
- vcvtps2dq(vreg_dst(idx), vreg_dst(idx));
- }
-
- auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
- switch (dst_type) {
- case data_type::s8:
- vpmovsdb(dst_addr, vreg_dst_);
- break;
- case data_type::u8:
- vpmovusdb(dst_addr, vreg_dst_);
- break;
- case data_type::f32:
- case data_type::s32:
- vmovups(dst_addr, vreg_dst_);
- break;
- default: assert(!"unimplemented");
- }
- };
-
- // Advance all pointers by an immediate
- auto advance_ptrs_imm = [&](size_t offset) {
- add(reg_dst, offset * sizeof(dst_data_t));
- add(reg_acc, offset * sizeof(acc_data_t));
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- add(reg_scales, offset * sizeof(float));
- }
- if (do_bias_)
- add(reg_bias, offset * bias_data_type_size_);
- };
-
- // Advance all pointers by a value stored in a register
- auto advance_ptrs_reg = [&](Reg64 offset) {
- lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
- lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
- }
- if (do_bias_)
- lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
- };
-
- // Rewind pointers that point to data that is indexed by output channel
- // (bias or per-oc scaling factors)
- auto rewind_ptrs = [&]() {
- if (do_bias_)
- sub(reg_bias, OC_ * bias_data_type_size_);
- if (scale_idx_mult_) {
- assert(scale_idx_mult_ == 1);
- sub(reg_scales, OC_ * sizeof(float));
- }
- };
-
- // <-------------------- OC ------------------------------->
- //
- // ^ +....................+----------------------------------+
- // | : not accessed | Prologue loop |
- // | +--------------------+----------------------------------+
- // | |
- // M | Main loop (unrolled) |
- // B | |
- // +--------------------------------+----------------------+
- // | | Epilogue loop | not accessed :
- // v +--------------------------------+......................+
-
- Label prologue_end;
- cmp(reg_oc_offset, 0);
- je(prologue_end, T_NEAR);
-
- // Prologue loop
- {
- mov(reg_tmp, OC_);
- sub(reg_tmp, reg_oc_offset);
- cmp(reg_tmp, reg_len);
- cmovg(reg_tmp, reg_len);
- sub(reg_len, reg_tmp);
-
- Label prologue_loop, prologue_loop_tail, prologue_loop_end;
- cmp(reg_tmp, vlen);
- jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
- L(prologue_loop); {
- compute(0, 0, false);
- advance_ptrs_imm(vlen);
- sub(reg_tmp, vlen);
- cmp(reg_tmp, vlen);
- jge(prologue_loop, T_NEAR);
- }
-
- L(prologue_loop_tail);
- mov(reg_rem_mask, 1);
- shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
- sub(reg_rem_mask, 1);
- jz(prologue_loop_end, T_NEAR);
-
- kmovq(kreg_rem_mask, reg_rem_mask);
- compute(0, 0, true);
- advance_ptrs_reg(reg_tmp);
-
- L(prologue_loop_end);
- rewind_ptrs();
- }
- L(prologue_end);
-
- // Main loop
- Label main_loop_end;
- {
- cmp(reg_len, OC_);
- jle(main_loop_end, T_NEAR);
-
- Label main_loop;
- L(main_loop); {
- size_t def_unroll = 4;
- size_t max_unroll = 13;
-
- size_t OC_loop, OC_tail;
- if (OC_ < max_unroll * vlen) {
- // Fully unroll small loops
- OC_loop = 0;
- OC_tail = OC_;
- } else {
- OC_loop = vlen * def_unroll;
- OC_tail = OC_ % OC_loop;
- }
-
- assert(!!OC_loop || !!OC_tail);
-
- if (OC_tail % vlen) {
- int vlen_tail = OC_tail % vlen;
- unsigned tail_mask = (1 << vlen_tail) - 1;
- mov(reg_tmp, tail_mask);
- kmovq(kreg_rem_mask, reg_tmp);
- }
-
- if (OC_loop) {
- mov(reg_tmp, rnd_dn(OC_, OC_loop));
- Label oc_loop;
- L(oc_loop); {
- for (size_t offset = 0; offset < OC_loop; offset += vlen)
- compute(offset, offset / vlen, false);
- advance_ptrs_imm(OC_loop);
- sub(reg_tmp, OC_loop);
- jnz(oc_loop);
- }
- }
-
- if (OC_tail) {
- for (size_t offset = 0; offset < OC_tail; offset += vlen) {
- bool use_mask = (offset + vlen) > OC_tail;
- compute(offset, offset / vlen, use_mask);
- }
- advance_ptrs_imm(OC_tail);
- }
-
- rewind_ptrs();
- sub(reg_len, OC_);
- cmp(reg_len, OC_);
- jge(main_loop, T_NEAR);
- }
- }
- L(main_loop_end);
-
- // Epilogue loop
- Label epilogue_end;
- {
- cmp(reg_len, 0);
- je(epilogue_end, T_NEAR);
-
- Label epilogue_loop, epilogue_loop_tail;
- cmp(reg_len, vlen);
- jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
- L(epilogue_loop); {
- compute(0, 0, false);
- sub(reg_len, vlen);
- advance_ptrs_imm(vlen);
- cmp(reg_len, vlen);
- jge(epilogue_loop, T_NEAR);
- }
-
- L(epilogue_loop_tail);
- mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
- mov(reg_rem_mask, 1);
- shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
- sub(reg_rem_mask, 1);
- jz(epilogue_end, T_NEAR);
- kmovq(kreg_rem_mask, reg_rem_mask);
- compute(0, 0, true);
- }
-
- L(epilogue_end);
-
- postamble();
-
- ker_ = getCode<decltype(ker_)>();
-}
-
-template<data_type_t src_type, data_type_t dst_type>
-void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::operator ()(
- dst_data_t *dst, const acc_data_t *acc,
- const char *bias, const float *scales, float nslope,
- size_t start, size_t end)
-{
- using math::get_bias;
-
- if (end <= start)
- return;
-
- if (ker_) {
- // JIT
- ker_args args;
- size_t oc_offset = start % OC_;
- args.dst = dst + start;
- args.acc = acc + start;
- args.bias = bias + oc_offset * bias_data_type_size_;
- args.scales = scales + scale_idx_mult_ * oc_offset;
- args.nslope = nslope;
- args.len = end - start;
- args.oc_offset = oc_offset;
- ker_(&args);
- } else {
- // Fallback
- size_t oc = start % OC_;
- for (size_t i = start; i < end; i++) {
- float d = (float)acc[i];
- float b = get_bias(bias, oc, bias_data_type_);
- d = d + b;
- d *= scales[oc * scale_idx_mult_];
- if (do_relu_ && d < 0)
- d *= nslope;
- dst[i] = qz_a1b0<float, dst_data_t>()(d);
- oc = (oc == OC_ - 1) ? 0 : oc + 1;
- }
- }
-};
-
-template <data_type_t src_type, data_type_t dst_type>
-void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
-
- bool wei_tr = memory_desc_matches_one_of_tag(
- *pd()->weights_md(), oiw, oihw, oidhw, oi);
-
- const int M = OC;
- const int N = MB;
- const int K = pd()->IC_total_padded();
- const int8_t off_a = 0, off_b = 0;
- const int32_t off_c = 0;
-
- const float *scales = pd()->attr()->output_scales_.scales_;
-
- const auto &post_ops = pd()->attr()->post_ops_;
- const bool do_relu = post_ops.len_ == 1;
- const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
-
- acc_data_t *acc = pd()->dst_is_acc_
- ? (acc_data_t *)dst
- : scratchpad(ctx).template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);
-
- const float onef = 1.0, zerof = 0.0;
- gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights,
- wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c);
-
- if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_
- || pd()->with_bias()) {
- const bool force_sequential = MB * OC < 2000;
- parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
- size_t start, end;
- balance211((size_t)OC * MB, nthr, ithr, start, end);
- (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end);
- });
- }
-}
-
-using namespace data_type;
-
-template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
-template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;
-
-}
-}
-}
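Note on the removed pp_kernel_t fallback (the non-JIT branch above): it walks the accumulator linearly, cycling an output-channel index modulo OC_, and applies bias add, per-channel output scale, optional leaky ReLU, and quantization to the destination type. A minimal standalone sketch of that reference loop, assuming an s32 accumulator, f32 bias, and an s8 destination (helper names here are illustrative, not the mkl-dnn API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Simplified stand-in for qz_a1b0<float, int8_t>: round to nearest, saturate to s8.
static int8_t quantize_s8(float v) {
    v = std::nearbyint(v);
    v = std::min(127.f, std::max(-128.f, v));
    return static_cast<int8_t>(v);
}

// Reference post-processing over elements [start, end) of a row-major MB x OC buffer.
static void pp_reference(int8_t *dst, const int32_t *acc, const float *bias,
                         const float *scales, float nslope, bool do_relu,
                         size_t OC, size_t start, size_t end) {
    size_t oc = start % OC;
    for (size_t i = start; i < end; ++i) {
        float d = static_cast<float>(acc[i]) + bias[oc]; // bias add
        d *= scales[oc];                                 // per-oc output scale
        if (do_relu && d < 0.f)
            d *= nslope;                                 // leaky-ReLU negative slope
        dst[i] = quantize_s8(d);
        oc = (oc == OC - 1) ? 0 : oc + 1;                // wrap to the next output channel
    }
}

The JIT variant above implements the same per-element recipe, but vectorized and split into the prologue/main/epilogue loops shown in the ASCII diagram.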
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp
deleted file mode 100644
index ac6a5c8f85..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp
+++ /dev/null
@@ -1,166 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP
-#define GEMM_X8S8S32X_INNER_PRODUCT_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "gemm/gemm.hpp"
-#include "jit_generator.hpp"
-
-#include "cpu_inner_product_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_fwd_pd_t {
- using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(src_type == data_type::u8
- ? IGEMM_S8U8S32_IMPL_STR
- : IGEMM_S8S8S32_IMPL_STR,
- gemm_x8s8s32x_inner_product_fwd_t);
-
- status_t init() {
- using namespace data_type;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && !has_zero_dim_memory()
- && src_md()->data_type == src_type
- && dst_md()->data_type == dst_type
- && weights_md()->data_type == s8
- && IMPLICATION(with_bias(), utils::one_of(
- weights_md(1)->data_type, f32, s32, s8, u8))
- && attr()->post_ops_.len_ <= 1
- && IMPLICATION(attr()->post_ops_.len_,
- attr()->post_ops_.entry_[0].is_relu(true, false))
- && dense_gemm_consitency_check(src_md(), weights_md(),
- dst_md());
- if (!ok) return status::unimplemented;
-
- dst_is_acc_ = utils::one_of(dst_type, s32, f32);
-
- init_scratchpad();
-
- return status::success;
- }
-
- bool dst_is_acc_;
-
- protected:
- status_t set_default_params() {
- using namespace format_tag;
- if (src_md_.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md_,
- utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc)));
- }
- if (dst_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_md_, nc));
- if (weights_md_.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(weights_md_,
- utils::pick(ndims() - 2, io, wio, hwio, dhwio)));
- }
- return inner_product_fwd_pd_t::set_default_params();
- }
-
- private:
- void init_scratchpad() {
- if (!dst_is_acc_) {
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(
- memory_tracking::names::key_iprod_int_dat_in_acc_dt,
- sizeof(acc_data_t) * MB() * OC());
- }
- }
- };
-
- gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd, true)
- { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); }
- ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; }
-
- typedef typename prec_traits<dst_type>::type data_t;
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
- typedef typename prec_traits<data_type::s32>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- // XXX: this is throwaway code that will become unnecessary when we have a
- // sufficiently advanced igemm jit generator that supports quantization,
- // relu, and whatnot
- class pp_kernel_t: jit_generator {
- public:
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- gemm_x8s8s32x_inner_product_fwd_t::pp_kernel);
- pp_kernel_t(const pd_t *pd, bool dst_is_acc);
-
- void operator()(dst_data_t *dst, const acc_data_t *acc,
- const char *bias, const float *scales, float nslope,
- size_t start, size_t end);
- private:
- void generate();
-
- struct ker_args {
- dst_data_t *dst;
- const acc_data_t *acc;
- const char *bias;
- const float *scales;
- float nslope;
- size_t len;
- size_t oc_offset;
- };
- void (*ker_)(const ker_args *args);
-
- size_t OC_;
- data_type_t bias_data_type_;
- size_t bias_data_type_size_;
- size_t scale_idx_mult_;
- bool do_bias_;
- bool do_relu_;
- };
-
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- pp_kernel_t *pp_kernel_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
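Note on dst_is_acc_ in the removed pd_t above: it is set only when the destination type (f32 or s32) can hold the s32 GEMM accumulator directly; otherwise init_scratchpad() books an intermediate buffer of MB() * OC() accumulator elements that the post-processing kernel later converts. A small standalone sketch of that decision, with hypothetical names standing in for the scratchpad registry:

#include <cstdint>
#include <vector>

struct AccPlan {
    bool dst_is_acc;                // GEMM may write straight into dst
    std::vector<int32_t> scratch;   // otherwise: intermediate s32 accumulator
};

static AccPlan plan_accumulator(bool dst_is_f32_or_s32, int MB, int OC) {
    AccPlan p;
    p.dst_is_acc = dst_is_f32_or_s32;
    if (!p.dst_is_acc)
        p.scratch.resize(static_cast<size_t>(MB) * static_cast<size_t>(OC));
    return p;
}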
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp
deleted file mode 100644
index 6fa251d465..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp
+++ /dev/null
@@ -1,674 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-* Copyright 2018 YANDEX LLC
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_memory.hpp"
-
-#include "jit_avx2_1x1_conv_kernel_f32.hpp"
-
-#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
-{
- mov(aux1_reg_bcast_data, reg_bcast_data);
- mov(aux_reg_output_data, reg_output_data);
- mov(bcast_loop_iter, reg_bcast_loop_work);
-
- Label bcast_loop, bcast_loop_tail;
-
- cmp(bcast_loop_iter, jcp.ur);
- jl(bcast_loop_tail, T_NEAR);
-
- L(bcast_loop); {
- assert(jcp.bcast_block % jcp.ur == 0);
- int num_substeps = jcp.bcast_block / jcp.ur;
- assert(num_substeps > 0 && num_substeps < 10);
- for (int i = 0; i < num_substeps; i++) {
- generate_reduce_loop(load_loop_blk, jcp.ur);
- if (i < num_substeps - 1) {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_substep);
- } else {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
- - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_step
- - (num_substeps - 1) * jcp.bcast_loop_output_substep);
- }
- }
- sub(bcast_loop_iter, jcp.bcast_block);
- cmp(bcast_loop_iter, jcp.bcast_block);
- jge(bcast_loop, T_NEAR);
- }
-
- L(bcast_loop_tail);
- if (jcp.ur_tail) {
- Label bcast_loop_tail_out;
- cmp(bcast_loop_iter, 0);
- jz(bcast_loop_tail_out, T_NEAR);
- generate_reduce_loop(load_loop_blk, jcp.ur_tail);
- L(bcast_loop_tail_out);
- }
-}
-
-void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
- int load_loop_blk, int ur)
-{
- auto vreg_load = [=](int i) {
- return Ymm(ur * load_loop_blk + i);
- };
-
- auto vreg_accum = [=](int i, int j) {
- return Ymm(j * load_loop_blk + i);
- };
-
- auto bias_ptr = [=](int i) {
- return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
- };
-
- auto bcast_ptr = [=](int u, int j) {
- assert(j < jcp.ur);
- assert(u <= jcp.reduce_loop_unroll);
- size_t offt;
- if (one_of(jcp.prop_kind,
- forward_training, forward_inference, backward_data))
- {
- assert(jcp.reduce_loop_unroll == (jcp.prop_kind == backward_data)
- ? jcp.oc_block : jcp.ic_block);
- auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
- offt = (u == jcp.reduce_loop_unroll)
- ? (height + j) * jcp.reduce_loop_unroll
- : j * jcp.reduce_loop_unroll + u;
- } else
- offt = u * jcp.ic_block + j;
- return ptr[aux_reg_bcast_data + sizeof(float) * offt];
- };
-
- auto load_ptr = [=](int u, int i) {
- size_t offt;
- size_t u0 = u % jcp.reduce_loop_unroll;
- size_t u1 = u / jcp.reduce_loop_unroll;
- switch (jcp.prop_kind) {
- case backward_data:
- offt = (i * jcp.oc_block + u0) * jcp.ic_block;
- break;
- case backward_weights:
- offt = (i * jcp.os + u0) * jcp.oc_block;
- break;
- default:
- offt = (i * jcp.ic + u0) * jcp.oc_block;
- }
- return ptr[aux_reg_load_data
- + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt];
- };
-
- auto output_ptr = [=](int i, int j) {
- switch (jcp.prop_kind) {
- case backward_data:
- return ptr[aux_reg_output_data +
- (i * jcp.is + j) * jcp.ic_block * sizeof(float)];
- case backward_weights:
- return ptr[aux_reg_output_data
- + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
- + sizeof(float) * jcp.oc_block * j];
- default:
- return ptr[aux_reg_output_data +
- (i * jcp.os + j) * jcp.oc_block * sizeof(float)];
- }
- };
-
- auto init = [=]() {
- Label init_done, init_zero;
-
- if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
- forward_inference)) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jz(init_zero);
-
- for (int i = 0; i < load_loop_blk; i++)
- for (int j = 0; j < ur; ++j)
- vmovups(vreg_accum(i, j), bias_ptr(i));
- jmp(init_done);
- }
-
- L(init_zero);
- for (int i = 0; i < load_loop_blk; ++i)
- for (int j = 0; j < ur; ++j) {
- auto r = vreg_accum(i, j);
- vxorps(r, r, r);
- }
-
- L(init_done);
- for (int i = 0; i < load_loop_blk; ++i)
- vmovups(vreg_load(i), load_ptr(0, i));
- vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
- };
-
- auto store = [=]() {
- Label store_noadd;
-
- if (!jcp.with_sum) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jnz(store_noadd, T_NEAR);
- }
-
- for (int j = 0; j < ur; ++j)
- for (int i = 0; i < load_loop_blk; ++i) {
- auto r = vreg_accum(i, j);
- vaddps(r, r, output_ptr(i, j));
- }
-
- L(store_noadd);
-
- if (jcp.with_eltwise) {
- assert(ur * load_loop_blk < 14);
-
- Label store_norelu;
- test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
- jz(store_norelu, T_NEAR);
-
- eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
-
- L(store_norelu);
- }
-
- for (int j = 0; j < ur; ++j)
- for (int i = 0; i < load_loop_blk; ++i) {
- vmovups(output_ptr(i, j), vreg_accum(i, j));
- }
- };
-
- auto fma_block = [=](bool last_block) {
- for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
- for (int j = 0; j < ur; ++j) {
- for (int i = 0; i < load_loop_blk; ++i) {
- if (mayiuse(avx2))
- vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
- else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
- vmulps(vtmp, vreg_bcast, vreg_load(i));
- vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
- }
- if (j == ur - 1 && !(last_block
- && u == jcp.reduce_loop_unroll - 1))
- vmovups(vreg_load(i), load_ptr(u + 1, i));
- }
- if (j < ur - 1)
- vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
- }
- if (!last_block || u < jcp.reduce_loop_unroll - 1)
- vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
- }
- };
-
- Label reduce_loop, reduce_loop_tail;
-
- mov(aux_reg_load_data, reg_load_data);
- mov(aux_reg_bcast_data, aux1_reg_bcast_data);
-
- init();
-
- mov(reduce_loop_iter, reg_reduce_loop_work);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jle(reduce_loop_tail, T_NEAR);
-
- L(reduce_loop); {
- fma_block(false);
- add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jg(reduce_loop, T_NEAR);
- }
-
- L(reduce_loop_tail);
- fma_block(true);
-
- store();
-}
-
-void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
-{
- if (!jcp.with_bias || jcp.prop_kind != backward_weights)
- return;
-
- Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
- Label diff_bias_load;
-
- auto diff_bias_ptr = [=](int i) {
- return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
- };
-
- auto load_ptr = [=](int u, int i) {
- return ptr[aux_reg_load_data
- + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
- };
-
- auto diff_bias_reg = [=](int i) { return Ymm(i); };
-
- mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
- cmp(reg_diff_bias_data, 0);
- je(diff_bias_loop_out, T_NEAR);
-
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jz(diff_bias_load, T_NEAR);
-
- for (int i = 0; i < load_loop_blk; ++i) {
- auto r = diff_bias_reg(i);
- vxorps(r, r, r);
- }
- jmp(diff_bias_init_out, T_NEAR);
-
- L(diff_bias_load);
- for (int i = 0; i < load_loop_blk; ++i)
- vmovups(diff_bias_reg(i), diff_bias_ptr(i));
-
- L(diff_bias_init_out);
- mov(aux_reg_load_data, reg_load_data);
- mov(reduce_loop_iter, reg_reduce_loop_work);
- L(diff_bias_loop); {
- for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
- for (int i = 0; i < load_loop_blk; ++i)
- vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
- assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jnz(diff_bias_loop, T_NEAR);
- }
-
- for (int i = 0; i < load_loop_blk; i++)
- vmovups(diff_bias_ptr(i), diff_bias_reg(i));
- add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
- mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
-
- L(diff_bias_loop_out);
-}
-
-void jit_avx2_1x1_conv_kernel_f32::generate()
-{
- preamble();
-
- mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
- mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
- mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
- if (jcp.with_bias) {
- if (jcp.prop_kind == backward_weights) {
- sub(rsp, stack_space_needed);
- mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
- mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
- } else
- mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
- }
-
- mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
- mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
- mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
- mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
- if (jcp.prop_kind == backward_weights)
- mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
-
- auto generate_load_loop_body = [=] (int load_loop_blk) {
- generate_bcast_loop(load_loop_blk);
- add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
- switch (jcp.prop_kind) {
- case forward_training:
- case forward_inference:
- add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
- add(reg_output_data,
- load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
- break;
- case backward_data:
- add(reg_output_data,
- load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
- break;
- case backward_weights:
- for (int i = 0; i < load_loop_blk; i++)
- add(reg_output_data, reg_output_stride);
- break;
- default:
- assert(!"invalid prop_kind");
- }
- sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- };
-
- Label load_loop_blk_8;
- Label load_loop_blk_16;
- Label load_loop_blk_24;
- Label load_loop_blk_end;
-
- cmp(reg_load_loop_work, 8);
- jle(load_loop_blk_8, T_NEAR);
-
- cmp(reg_load_loop_work, 32);
- je(load_loop_blk_16, T_NEAR);
-
- cmp(reg_load_loop_work, 16);
- jle(load_loop_blk_16, T_NEAR);
-
- L(load_loop_blk_24); {
- generate_diff_bias_loop(3);
- generate_load_loop_body(3);
- cmp(reg_load_loop_work, 32);
- je(load_loop_blk_16);
- cmp(reg_load_loop_work, 24);
- jge(load_loop_blk_24);
- }
-
- cmp(reg_load_loop_work, 8);
- jle(load_loop_blk_8, T_NEAR);
-
- L(load_loop_blk_16); {
- generate_diff_bias_loop(2);
- generate_load_loop_body(2);
- cmp(reg_load_loop_work, 16);
- jge(load_loop_blk_16);
- }
-
- L(load_loop_blk_8); {
- cmp(reg_load_loop_work, 0);
- je(load_loop_blk_end, T_NEAR);
- generate_diff_bias_loop(1);
- generate_load_loop_body(1);
- }
-
- L(load_loop_blk_end);
-
- if (jcp.with_bias && jcp.prop_kind == backward_weights)
- add(rsp, 8);
-
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
- jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
-{
- if (!mayiuse(avx)) return status::unimplemented;
-
- // TODO (Roma): this code is duplicated from the generic kernel; maybe the
- // configuration struct could do some stuff below
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- const int ndims = src_d.ndims();
-
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
- jcp.ow = dst_d.dims()[ndims - 1];
-
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
- jcp.l_pad = cd.padding[0][ndims - 3];
-
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
- jcp.stride_w = cd.strides[ndims - 3];
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- jcp.os = jcp.oh * jcp.ow;
- jcp.is = jcp.ih * jcp.iw;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise) {
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
- if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu)
- return status::unimplemented;
- }
-
- const int is_bwd_d = jcp.prop_kind == backward_data;
-
- format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c;
- format_tag_t wei_tag = with_groups
- ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
- gOIhw8o8i)
- : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
- OIhw8o8i);
-
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
-
- const int simd_w = 8;
-
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
-
- bool args_ok = true
- && jcp.ngroups == 1
- && jcp.src_tag == dat_tag
- && jcp.wei_tag == wei_tag
- && jcp.dst_tag == dat_tag;
- if (!args_ok) return status::unimplemented;
-
- args_ok = true
- && jcp.ih == jcp.oh && jcp.iw == jcp.ow
- && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
- && jcp.t_pad == 0 && jcp.l_pad == 0
- && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
- && jcp.kh == 1 && jcp.kw == 1;
- if (!args_ok) return status::unimplemented;
-
- // TODO: remove this restriction
- // optimized 1x1 bwd_w does not support Intel AVX
- if (jcp.prop_kind == backward_weights && !mayiuse(avx2))
- return status::unimplemented;
-
- jcp.ic_block = jcp.oc_block = simd_w;
-
- jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support
-
- int load_blocking{ 0 };
- int load_blocking_max{ 0 };
- int bcast_blocking{ 0 };
- int bcast_blocking_max{ 0 };
- int reduce_blocking{ 0 };
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- jcp.reduce_dim = jcp.ic;
- jcp.reduce_block = jcp.ic_block;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.is;
- jcp.bcast_block = jcp.ur;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
- jcp.load_loop_iter_step = jcp.oc_block;
-
- load_blocking = 120; // assumes the kernel is jcp.ur x 3
- load_blocking_max = 144;
- bcast_blocking = 128; // affects load balancing across threads
- bcast_blocking_max = 192;
- reduce_blocking = 128; // affects L1$ utilization
- } else if (jcp.prop_kind == backward_data) {
- jcp.reduce_dim = jcp.oc;
- jcp.reduce_block = jcp.oc_block;
-
- jcp.load_dim = jcp.ic;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.os;
- jcp.bcast_block = jcp.ur;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
- jcp.load_loop_iter_step = jcp.ic_block;
-
- load_blocking = 96; // assumes the kernel is jcp.ur x 3
- load_blocking_max = 144;
- bcast_blocking = 128; // affects load balancing across threads
- bcast_blocking_max = 196;
- reduce_blocking = 64; // affects L1$ utilization
- } else if (jcp.prop_kind == backward_weights) {
- jcp.reduce_dim = jcp.os;
- jcp.reduce_block = 1;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.ic;
- jcp.bcast_block = jcp.ic_block;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
- jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
- jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
-
- jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
- jcp.load_loop_iter_step = jcp.oc_block;
-
- /* --- */
-
- load_blocking = div_up(jcp.load_dim, jcp.load_block);
- while (true) {
- if (load_blocking <= 32) break;
- else if (load_blocking % 2 == 0) load_blocking /= 2;
- else if (load_blocking % 3 == 0) load_blocking /= 3;
- else break;
- }
- load_blocking *= jcp.load_block;
- load_blocking_max = load_blocking;
- assert(jcp.load_dim % load_blocking == 0);
-
- bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- while (true) {
- if (bcast_blocking <= 9) break;
- else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
- else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
- else break;
- }
- bcast_blocking *= jcp.bcast_block;
- bcast_blocking_max = bcast_blocking;
- assert(jcp.bcast_dim % bcast_blocking == 0);
-
- reduce_blocking = 128; // affects L1$ utilization
- } else
- return status::unimplemented;
-
- assert(load_blocking);
- assert(load_blocking_max);
- assert(bcast_blocking);
- assert(bcast_blocking_max);
- assert(reduce_blocking);
-
- assert(jcp.bcast_block % jcp.ur == 0);
- jcp.ur_tail = jcp.bcast_dim % jcp.ur;
-
- jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
- jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
- jcp.nb_load_blocking = load_blocking / jcp.load_block;
- jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
- jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
-
- jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
- jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- return status::success;
-}
-
-void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp) {
- using namespace mkldnn::impl::memory_tracking::names;
-
- if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
-}
-
-}
-}
-}
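Note on the backward-weights blocking in the removed init_conf() above: load_blocking and bcast_blocking start from the number of blocks and are repeatedly divided by 2 or 3 until they drop to the target (32 and 9 respectively) or no small factor remains, so the result still divides the original dimension exactly. A standalone sketch of that heuristic (the function name is illustrative only):

#include <cstdio>

// Divide nblocks by 2 or 3 while it exceeds the limit and a small factor exists.
static int shrink_blocking(int nblocks, int limit) {
    while (nblocks > limit) {
        if (nblocks % 2 == 0)
            nblocks /= 2;
        else if (nblocks % 3 == 0)
            nblocks /= 3;
        else
            break;  // no small factor left; accept the current count
    }
    return nblocks;
}

int main() {
    // e.g. 96 load blocks with limit 32: 96 -> 48 -> 24
    std::printf("%d\n", shrink_blocking(96, 32));
    return 0;
}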
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp
deleted file mode 100644
index bfdeb2b18d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp
+++ /dev/null
@@ -1,110 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
-#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "cpu_memory.hpp"
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx2_1x1_conv_kernel_f32: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32)
-
- jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this,
- jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
- }
-
- ~jit_avx2_1x1_conv_kernel_f32() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp);
-
- jit_1x1_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_1x1_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using ymm_t = const Xbyak::Ymm;
-
- reg64_t reg_bcast_data = rax;
- reg64_t reg_load_data = rsi;
- reg64_t reg_output_data = rbx;
- reg64_t aux_reg_bcast_data = rdx;
- reg64_t aux1_reg_bcast_data = abi_not_param1;
- reg64_t aux_reg_load_data = abi_param1;
- reg64_t aux_reg_output_data = rbp;
- reg64_t reg_load_loop_work = r9;
- reg64_t reg_bcast_loop_work = r10;
- reg64_t reg_reduce_loop_work = r11;
- reg64_t load_loop_iter = r13;
- reg64_t bcast_loop_iter = r14;
- reg64_t reduce_loop_iter = r15;
- reg64_t imm_addr64 = reduce_loop_iter;
- reg64_t reg_reduce_pos_flag = r8;
- reg64_t reg_output_stride = r12;
- reg64_t reg_bias_data = r12;
- reg64_t reg_diff_bias_data = bcast_loop_iter;
-
- int reg_diff_bias_data_stack_offt = 0;
- int stack_space_needed = 8;
-
- ymm_t vreg_bcast = ymm_t(15);
- ymm_t vtmp = ymm_t(14);
-
- jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_;
-
- void generate_bcast_loop(int load_loop_blk);
- void generate_reduce_loop(int load_loop_blk, int ur);
- void generate_diff_bias_loop(int load_loop_blk);
-
- void generate();
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp
deleted file mode 100644
index f116ac9056..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp
+++ /dev/null
@@ -1,545 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-#include "jit_avx2_1x1_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-#define data_blk_off(f, n, c, h, w) \
- ((ndims == 3) \
- ? (f).blk_off(n, c, w) \
- : (f).blk_off(n, c, h, w))
-
-/* convolution forward */
-
-void jit_avx2_1x1_convolution_fwd_t::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space);
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
- const int ndims = dst_d.ndims();
-
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- auto ker = [&](const int ithr, const int nthr) {
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
-
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx2>::call_params_t();
-
- const int nb_oc = jcp.nb_load;
- const int nb_ic = jcp.nb_reduce;
- const int nb_ic_blocking = jcp.nb_reduce_blocking;
- const int os_block = jcp.bcast_block;
-
- int start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- int iwork = start;
- while (iwork < end) {
- int n{0}, g{0}, osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
-
- int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, end - iwork);
-
- const int os = osb * os_block;
- const int oh = os / jcp.ow;
- const int ow = os % jcp.ow;
-
- const int ih = nstl::max(oh * stride_h - pad_t, 0);
- const int iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
- rp.os = p.bcast_dim;
-
- int ocb = 0;
- while (ocb < jcp.nb_load) {
- const int load_step = step(jcp.nb_load_blocking,
- jcp.nb_load - ocb, jcp.nb_load_blocking_max);
-
- const int _ocb = g * nb_oc + ocb;
- p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
- load_step * jcp.oc_block);
- const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
-
- p.output_data = &dst[dst_off];
-
- p.bias_data = &bias[_ocb * jcp.oc_block];
-
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- p.first_last_flag = 0
- | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
- | (icb + nb_ic_blocking >= nb_ic
- ? FLAG_REDUCE_LAST : 0);
-
- p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
- nb_ic_blocking * jcp.ic_block);
- rp.icb = p.reduce_dim / jcp.reduce_block;
-
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- const int _icb = g * nb_ic + icb;
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_
- + _icb * jcp.is * jcp.ic_block;
-
- if (ocb == 0) {
- rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
- rtus_driver_->ker_(&rp);
- }
-
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
-
- kernel_->jit_ker(&p);
- }
-
- ocb += load_step;
- }
-
- iwork += bcast_step;
- }
- };
-
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
- utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bias = padded_bias;
- }
-
- parallel(0, ker);
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
-
-/* convolution backward w.r.t. data */
-
-void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space);
-
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
- const int ndims = diff_dst_d.ndims();
-
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- const int nb_ic = jcp.nb_load;
- const int nb_oc = jcp.nb_reduce;
- const int os_block = jcp.bcast_block;
- const int nb_oc_blocking = jcp.nb_reduce_blocking;
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- auto ker = [&](const int ithr, const int nthr) {
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx2>::call_params_t();
-
- int start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- int load_step = 0;
- for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
- load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
- jcp.nb_load_blocking_max);
-
- p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
- load_step * jcp.ic_block);
- rp.icb = p.load_dim / jcp.ic_block;
-
- int bcast_step;
- for (int iwork = start; iwork < end; iwork += bcast_step) {
- int n{0}, g{0}, osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
-
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, end - iwork);
-
- const int os = osb * os_block;
- p.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
- rp.os = p.bcast_dim;
-
- const int oh = os / jcp.ow;
- const int ow = os % jcp.ow;
- const int ih = nstl::max(oh * stride_h - pad_t, 0);
- const int iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- const int _icb = g * nb_ic + icb;
- rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_;
- p.output_data = rp.ws;
- } else
- p.output_data = rp.src;
-
- for (int ocb = 0; ocb < jcp.nb_reduce;
- ocb += jcp.nb_reduce_blocking) {
- const int _ocb = g * nb_oc + ocb;
- size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh,
- ow);
- p.bcast_data = &diff_dst[diff_dst_off];
-
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
-
- p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
- nb_oc_blocking * jcp.oc_block);
-
- kernel_->jit_ker(&p);
- }
-
- if (pd()->rtus_.reduce_src_)
- rtus_driver_->ker_(&rp);
- }
- }
- };
-
- parallel(0, ker);
-}
-
-/* convolution backward w.r.t. weights */
-
-jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t(
- const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr)
- , rtus_driver_(nullptr)
-{
- kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
- reducer_weights_ =
- new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_);
- reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
- init_rtus_driver<avx2>(this);
-}
-
-void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- auto scratchpad = this->scratchpad(ctx);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
- const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
-
- data_t *diff_bias = pd()->wants_padded_bias()
- ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
-
- auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_bia);
- auto rb = this->reducer_bias_;
- rb->init(reducer_bia_scratchpad);
-
- auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_wei);
- auto rw = this->reducer_weights_;
- rw->init(reducer_wei_scratchpad);
-
- const int ndims = diff_dst_d.ndims();
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
-
- const int nb_ic = jcp.nb_bcast;
- const int nb_ic_blocking = jcp.nb_bcast_blocking;
- const int bcast_work = div_up(nb_ic, nb_ic_blocking);
-
- const int nb_oc = jcp.nb_load;
- const int nb_oc_blocking = jcp.nb_load_blocking;
- const int load_work = div_up(nb_oc, nb_oc_blocking);
-
- const int sp_dim = jcp.reduce_dim;
- const int mb_sp_work = jcp.mb * sp_dim;
-
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
- data_t *store_to, size_t store_to_ld, const data_t *diff_dst,
- const data_t *src, int ithr) {
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx2>::call_params_t();
-
- p.output_stride = store_to_ld * sizeof(float);
- const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block;
-
- int oc_b_step = 0;
- for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
- oc_b_step = step(12, nb_oc_blocking - oc_b, 18);
- p.load_dim = oc_b_step * jcp.oc_block;
-
- int ic_b_step = 0;
- for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
- ic_b_step = step(12, nb_ic_blocking - ic_b, 18);
- p.bcast_dim = ic_b_step * jcp.ic_block;
- rp.icb = p.bcast_dim / jcp.ic_block;
-
- p.output_data = store_to + oc_b * store_to_ld
- + ic_b * jcp.ic_block * jcp.oc_block;
-
- /* spatial reduction */
- int sp_step = 0;
- for (int sp = sp_start; sp < sp_end; sp += sp_step) {
- sp_step = step(sp_step_def, sp_end - sp, 192);
- p.reduce_dim = sp_step;
- rp.os = p.reduce_dim;
-
- p.first_last_flag = sp == sp_start && first_image
- ? FLAG_REDUCE_FIRST : 0;
-
- p.load_data = diff_dst
- + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block;
-
- if (pd()->rtus_.reduce_src_) {
- const int oh = sp / jcp.ow;
- const int ow = sp % jcp.ow;
-
- const int ih = nstl::max(oh * stride_h - pad_t, 0);
- const int iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_
- + (ic_b * jcp.is + sp) * jcp.ic_block;
- if (ndims == 3)
- rp.src = src
- + iw * src_d.blocking_desc().strides[2];
- else
- rp.src = src
- + ih * src_d.blocking_desc().strides[2]
- + iw * src_d.blocking_desc().strides[3];
-
- if (oc_b == 0)
- rtus_driver_->ker_(&rp);
-
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = src
- + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block;
-
- kernel_->jit_ker(&p);
- }
- }
- }
- };
-
- auto ker = [&](const int ithr, const int nthr) {
- assert(nthr == rw->balancer().nthr_);
-
- const int w_njobs = rw->balancer().ithr_njobs(ithr);
- if (w_njobs == 0) return;
-
- /* setup: independent work (oc, ic) */
- const int w_job_start = rw->balancer().ithr_job_off(ithr);
- int g{0}, load_i{0}, bcast_i{0};
- nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
- bcast_i, bcast_work);
-
- /* setup: reduction work (mb, sp) */
- int mb_sp_start{0}, mb_sp_end{0};
- balance211(mb_sp_work, rw->balancer().nthr_per_group_,
- rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
- int img_start{0}, sp_start{0};
- nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);
-
- /* independent work */
- for (int iwork = 0; iwork < w_njobs; ++iwork) {
- const int oc_b = nb_oc_blocking * load_i;
- const int ic_b = nb_ic_blocking * bcast_i;
-
- const int _ic_b = g * nb_ic + ic_b;
- const int _oc_b = g * nb_oc + oc_b;
-
- data_t *store_to;
- size_t store_to_ld;
-
- if (rw->balancer().nthr_per_group_ == 1) {
- const size_t off = pd()->with_groups()
- ? diff_weights_d.blk_off(g, oc_b, ic_b)
- : diff_weights_d.blk_off(oc_b, ic_b);
- store_to = &diff_weights[off];
- store_to_ld = jcp.ic * jcp.oc_block;
- } else {
- const size_t off = iwork * rw->balancer().job_size_;
- store_to =
- rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
- store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
- }
-
- /* reduction work */
- int img = img_start;
- int sp = sp_start;
- int sp_step = 0;
- for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step)
- {
- sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);
-
- const bool first_image = img == img_start;
- oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
- store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)],
- &src[src_d.blk_off(img, _ic_b)], ithr);
-
- sp = 0;
- img += 1;
- }
-
- nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i,
- bcast_work);
- }
- rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
- };
-
- auto ker_bias = [&](int ithr, int nthr) {
- assert(nthr == rb->balancer().nthr_);
-
- const int b_job_start = rb->balancer().ithr_job_off(ithr);
- const int b_njobs = rb->balancer().ithr_njobs(ithr);
-
- if (b_njobs == 0) return;
-
- /* reduction dimension */
- int img_start{0}, img_end{0};
- balance211(jcp.mb, rb->balancer().nthr_per_group_,
- rb->balancer().id_in_group(ithr), img_start, img_end);
-
- /* jobs */
- int g_start{0}, ocb_start{0};
- nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);
-
- for (int img = img_start; img < img_end; ++img) {
- int g = g_start, ocb = ocb_start;
- for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
- const size_t _oc = g * nb_oc + ocb;
-
- const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- data_t *d_bias =
- rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad)
- + b_job_loc * rb->balancer().job_size_;
-
- if (img == img_start)
- for (int o = 0; o < 8; ++o) d_bias[o] = 0.;
-
- for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
- PRAGMA_OMP_SIMD()
- for (int o = 0; o < 8; ++o)
- d_bias[o] += d_dst[o];
- d_dst += 8;
- }
-
- nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
- }
- }
- rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
- };
-
- parallel(0, [&](const int ithr, const int nthr) {
- ker(ithr, nthr);
- if (pd()->with_bias())
- ker_bias(ithr, nthr);
- });
-
- /* TODO: put this in ker_bias */
- if (pd()->wants_padded_bias()) {
- assert(jcp.ngroups == 1);
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- diff_bias_in[oc] = diff_bias[oc];
- }
-}
-
-}
-}
-}
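Note on the padded-bias handling in the removed forward and backward-weights paths above: when the padded output-channel count exceeds the real one, the valid bias entries are copied into a padded scratch buffer and the tail is zeroed (and, for backward weights, the valid entries are copied back after reduction). A minimal standalone sketch of the padding step, assuming an f32 bias (names are illustrative):

#include <algorithm>
#include <vector>

// Return a buffer of oc_padded entries: the first oc come from bias, the rest are zero.
static std::vector<float> pad_bias(const float *bias, int oc, int oc_padded) {
    std::vector<float> padded(static_cast<size_t>(oc_padded), 0.f);
    std::copy(bias, bias + oc, padded.begin());
    return padded;
}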
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp
deleted file mode 100644
index 9762242173..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp
+++ /dev/null
@@ -1,344 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
-#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_reducer.hpp"
-
-#include "jit_avx2_1x1_conv_kernel_f32.hpp"
-#include "jit_uni_1x1_conv_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t {
- // TODO: (Roma) Code duplication duplication! Remove with templates
- // (maybe...)!
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
- jit_avx2_1x1_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *src_d = src_md();
- rtus_prepare(this, conv_d, src_d, dst_md());
-
- status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
- *conv_d, *src_d, *weights_md(), *dst_md(), *attr());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- jit_1x1_conv_conf_t jcp_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
- : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx2_1x1_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), rtus_driver_(nullptr)
- {
- kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
- init_rtus_driver<avx2>(this);
- }
-
- ~jit_avx2_1x1_convolution_fwd_t() {
- delete kernel_;
- delete rtus_driver_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_1x1_conv_kernel_f32 *kernel_;
- rtus_driver_t<avx2> *rtus_driver_;
-};
-
-struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
- jit_avx2_1x1_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *diff_src_d = diff_src_md();
- rtus_prepare(this, conv_d, diff_src_d, diff_dst_md());
-
- status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
- *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
- *attr());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- jit_1x1_conv_conf_t jcp_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i)
- : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr)
- , rtus_driver_(nullptr)
- {
- kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
- init_rtus_driver<avx2>(this);
- }
-
- ~jit_avx2_1x1_convolution_bwd_data_t() {
- delete kernel_;
- delete rtus_driver_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_1x1_conv_kernel_f32 *kernel_;
- rtus_driver_t<avx2> *rtus_driver_;
-};
-
-struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
- jit_avx2_1x1_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *src_d = src_md();
- rtus_prepare(this, conv_d, src_d, diff_dst_md());
-
- status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
- *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
- *attr());
- if (status != status::success) return status;
-
- init_balancers();
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
-
- rtus_prepare_space_info(this, scratchpad);
-
- auto reducer_bia_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_bia);
- reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
-
- auto reducer_wei_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_wei);
- reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
-
- return status::success;
- }
-
- jit_1x1_conv_conf_t jcp_;
- cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
- cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
- : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
-
- private:
- void init_balancers() {
- const int ic_block = jcp_.bcast_block;
- const int nb_ic = jcp_.nb_bcast;
- const int nb_ic_blocking = jcp_.nb_bcast_blocking;
- const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);
-
- const int oc_block = jcp_.load_block;
- const int nb_oc = jcp_.nb_load;
- const int nb_oc_blocking = jcp_.nb_load_blocking;
- const int load_work = utils::div_up(nb_oc, nb_oc_blocking);
-
- const int job_size
- = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
- const int njobs_x = bcast_work;
- const int njobs_y = jcp_.ngroups * load_work;
-
- const int max_threads = mkldnn_get_max_threads();
- const size_t max_buffer_size = max_threads * job_size * 8;
-
- if (with_bias()) {
- reducer_bia_conf_.init(reduce_balancer_t(max_threads,
- oc_block, jcp_.ngroups * jcp_.oc / oc_block,
- jcp_.mb, max_buffer_size));
- }
-
- reducer_wei_conf_.init(
- reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
- jcp_.mb * jcp_.nb_reduce, max_buffer_size),
- job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
- nb_ic * ic_block * oc_block, nb_oc);
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd);
-
- ~jit_avx2_1x1_convolution_bwd_weights_t() {
- delete kernel_;
- delete rtus_driver_;
- delete reducer_weights_;
- delete reducer_bias_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_1x1_conv_kernel_f32 *kernel_;
- cpu_reducer_2d_t<data_type::f32> *reducer_weights_;
- cpu_reducer_t<data_type::f32> *reducer_bias_;
- rtus_driver_t<avx2> *rtus_driver_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp
deleted file mode 100644
index e24770a2da..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp
+++ /dev/null
@@ -1,1501 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-* Copyright 2018 YANDEX LLC
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "cpu_memory.hpp"
-
-#include "jit_avx2_conv_kernel_f32.hpp"
-
-#define GET_OFF(field) offsetof(jit_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
- int pad_l, int pad_r, int oc_blocks)
-{
- int iw = jcp.iw;
- int ih = jcp.ih;
- int id = jcp.id;
- int kw = jcp.kw;
- int kh = jcp.kh;
- int kd = jcp.kd;
- int nb_ic = jcp.nb_ic;
- int stride_w = jcp.stride_w;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
-
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
- int jj_end = ur_w
- - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
- for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- size_t inp_off;
- if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
- inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw
- + (ki*dilate_w + jj*stride_w - pad_l));
- else
- inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w
- - pad_l)*ic_blk + ifm2);
- vbroadcastss(Ymm(oc_blocks * ur_w + jj),
- make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
- }
-
- for (int ii = 0; ii < oc_blocks; ii++) {
- int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk
- + ki * ic_blk * oc_blk + ifm2 * oc_blk;
- vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]);
- for (int jj = jj_start; jj < jj_end; jj++)
- if (mayiuse(avx2))
- vfmadd231ps(Ymm(ur_w * ii + jj),
- Ymm(oc_blocks * ur_w + jj), ymm15);
- else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
- vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
- }
- }
- }
- }
-}
-
-void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
- int pad_l, int pad_r, char pad_tag,
- int oc_blocks, char oc_blocks_tag)
-{
- Label kw_loop;
-
- int iw = jcp.iw;
- int ih = jcp.ih;
- int id = jcp.id;
- int kw = jcp.kw;
- int kh = jcp.kh;
- int kd = jcp.kd;
- int nb_ic = jcp.nb_ic;
- int stride_w = jcp.stride_w;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
-
- xor_(ki_iter, ki_iter);
- L(kw_loop);
- {
- int jj_start = 0;
- int jj_end = ur_w;
- for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- size_t inp_off;
- if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
- inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw
- + (jj * stride_w - pad_l));
- else
- inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk
- + ifm2);
- vbroadcastss(Ymm(oc_blocks * ur_w + jj),
- make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
- }
- for (int ii = 0; ii < oc_blocks; ii++) {
- int aux_kernel_offset =
- ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk;
- vmovups(ymm15, ptr[aux_reg_kernel
- + sizeof(float) * aux_kernel_offset]);
- for (int jj = jj_start; jj < jj_end; jj++)
- if (mayiuse(avx2))
- vfmadd231ps(Ymm(ur_w * ii + jj),
- Ymm(oc_blocks * ur_w + jj), ymm15);
- else { // Intel AVX support
- vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
- }
- }
- }
- add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
- add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? dilate_w : ic_blk * dilate_w));
-
- inc(ki_iter);
- cmp(ki_iter, kw);
- jl(kw_loop, T_NEAR);
- }
-}
-
-void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
- int pad_l, int pad_r, char pad_tag,
- int oc_blocks, char oc_blocks_tag)
-{
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ow = jcp.ow;
- int oh = jcp.oh;
- int od = jcp.od;
- int dilate_h = jcp.dilate_h + 1;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? 1 : ic_blk;
- const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? dilate_w : ic_blk * dilate_w;
-
- Label init_done, init_first;
-
- if (!jcp.with_sum) {
- test(reg_ci_flag, FLAG_IC_FIRST);
- jne(init_first, T_NEAR);
- }
-
- for (int ii = 0; ii < oc_blocks; ii++) {
- for (int jj = 0; jj < ur_w; jj++) {
- size_t offt =
- sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
- vmovups(Ymm(ur_w * ii + jj),
- make_safe_addr(reg_output, offt, reg_long_offt));
- }
- }
-
- if (jcp.with_sum && jcp.with_bias) {
- test(reg_ci_flag, FLAG_IC_FIRST);
- je(init_done, T_NEAR);
-
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
- yword[reg_bias + sizeof(float) * ii * oc_blk]);
- }
-
- jmp(init_done);
-
- L(init_first);
- if (this->jcp.with_bias) {
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- vmovups(Ymm(ur_w * ii + jj),
- yword[reg_bias + sizeof(float) * ii * oc_blk]);
- } else {
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj));
- }
-
- L(init_done);
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
- }
-
- Label skip_kh_loop, skip_kd_loop, kd_loop;
- if (jcp.ndims == 5) {
- push(reg_output);
- push(oi_iter);
-
- mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
- mov(aux_reg_inp_d, reg_input);
-
- if ((jcp.dilate_d >= jcp.id)
- || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
- cmp(reg_ki, 0);
- je(skip_kd_loop, T_NEAR);
- }
- L(kd_loop);
- mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
- } else {
- mov(kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_input, aux_reg_inp_d);
- mov(aux_reg_kernel, aux_reg_ker_d);
- }
-
- if ((jcp.dilate_h >= jcp.ih)
- || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
- Label kh_loop;
- L(kh_loop);
- {
- if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
- oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
- oc_blocks_tag);
- sub(aux_reg_input, sizeof(float) * kw * inp_off);
- add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
- } else {
- oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
- add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
- add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
- }
-
- dec(kj);
- cmp(kj, 0);
- jg(kh_loop, T_NEAR);
- }
-
- L(skip_kh_loop);
-
- if (jcp.ndims == 5) {
- add(aux_reg_inp_d,
- sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult);
- add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_loop, T_NEAR);
- L(skip_kd_loop);
-
- pop(oi_iter);
- pop(reg_output);
- }
-
- Label regular_store;
-
- if (jcp.with_eltwise) {
- test(reg_ci_flag, FLAG_IC_LAST);
- je(regular_store, T_NEAR);
-
- eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w);
-
- L(regular_store);
- }
-
- for (int ii = 0; ii < oc_blocks; ii++) {
- for (int jj = 0; jj < ur_w; jj++) {
- const size_t o_off
- = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
- Ymm reg_out = Ymm(ur_w * ii + jj);
- vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
- }
- }
-}
-
-inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
- int oc_blocks, char oc_blocks_tag)
-{
- int ur_w = jcp.ur_w;
- int ur_w_tail = jcp.ur_w_tail;
- int n_oi = jcp.ow / ur_w;
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
- int dilate_w = jcp.dilate_w + 1;
- int str_w = jcp.stride_w;
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk;
-
- int l_pad = jcp.l_pad;
- int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
- - (iw + l_pad - 1));
- int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
- - (iw + l_pad - 1);
- if (r_pad1 > 0) n_oi--;
-
- if (l_pad > 0) {
- n_oi--;
- if (n_oi < 0 && r_pad1 > 0)
- width_blk_step(ur_w, l_pad, r_pad1,
- 'l', oc_blocks, oc_blocks_tag); // "lrpad"
- else
- width_blk_step(ur_w, l_pad, 0,
- 'l', oc_blocks, oc_blocks_tag); // "lpad"
- add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
- }
-
- Label ow_loop;
- xor_(oi_iter, oi_iter);
-
- if (n_oi > 0) {
- L(ow_loop);
-
- width_blk_step(ur_w, 0, 0,
- 'm', oc_blocks, oc_blocks_tag); // "middle"
- add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
-
- inc(oi_iter);
- cmp(oi_iter, n_oi);
- jl(ow_loop, T_NEAR);
- }
-
-    if (r_pad1 > 0 && n_oi >= 0) {
- width_blk_step(ur_w, 0, r_pad1,
- 'r', oc_blocks, oc_blocks_tag); // "rpad"
- add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
- }
-
- if (ur_w_tail != 0)
- width_blk_step(ur_w_tail, 0, r_pad,
- 't', oc_blocks, oc_blocks_tag); // "tail"
-}
-
-void jit_avx2_conv_fwd_kernel_f32::generate()
-{
- this->preamble();
-
- mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- if (jcp.with_bias)
- mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
- mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
- mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
- mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
-
- int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
- Label tail, exit;
-
- if (jcp.nb_oc > jcp.nb_oc_blocking) {
- cmp(reg_oc_blocks, jcp.nb_oc_blocking);
- jne(nb_oc_tail ? tail : exit, T_NEAR);
-
- solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
- jmp(exit, T_NEAR);
-
- if (nb_oc_tail) {
- L(tail);
- cmp(reg_oc_blocks, nb_oc_tail);
- jne(exit, T_NEAR);
- solve_common(nb_oc_tail, '0' + nb_oc_tail);
- }
-
- L(exit);
- } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
- solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
- } else {
- solve_common(nb_oc_tail, '0' + nb_oc_tail);
- }
-
- this->postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
-{
- if (!mayiuse(avx)) return status::unimplemented;
-
- jcp.prop_kind = cd.prop_kind;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- int ndims = src_d.ndims();
- jcp.ndims = ndims;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
- jcp.iw = src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
-    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
- jcp.ow = dst_d.dims()[ndims-1];
- jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
- jcp.kw = weights_d.dims()[with_groups + ndims-1];
-
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
-    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
-
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
-
- if (ndims == 3) {
- jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(
- Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
- jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c);
- } else if (ndims == 4) {
- jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(
- Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
- jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c);
- } else if (ndims == 5) {
- jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(
- Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
- jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c);
- }
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise) {
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
- if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu)
- return status::unimplemented;
- }
-
- const int simd_w = 8;
- const bool flat = jcp.ic < simd_w;
- const bool mimo = !flat;
-
-
- /* Grouped channel offset to support 'non-blocked data' format for
- * convolution sizes with '(input_channel / ngroups) < simd' */
- jcp.nonblk_group_off =
- one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1;
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1;
-
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- if (mimo)
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- bool args_ok = true
- && IMPLICATION(flat, true
- && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
- && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
- gOdhwi8o))
- && IMPLICATION(mimo, true
- && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
- && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
- OIdhw8i8o, gOIdhw8i8o))
- && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c);
- if (!args_ok) return status::unimplemented;
-
- jcp.ur_h = 1; /* no code-unrolling by h so far */
- jcp.ur_w = 3;
-
- jcp.oc_block = simd_w;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
-
- // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
- // Thus, we can only assign 14 or 15 YMMs for data storage
- const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
- if (!mayiuse(avx2)) {
- if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
- // current register assignment requires more YMMs than available
-            // adjust either nb_oc_blocking or ur_w, preserving ur_w >= l_pad
- if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
- jcp.ur_w -= 1;
- else
- for (int b = 3; b > 1; b--)
- if (jcp.nb_oc % b == 0) {
- jcp.nb_oc_blocking = b;
- break;
- }
- }
- }
-
- if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
-
- args_ok = true
- && jcp.oc % simd_w == 0
- && jcp.l_pad <= jcp.ur_w
- && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
- || (jcp.stride_w == 1 && jcp.stride_h == 1))
- && IMPLICATION(mimo, jcp.ic % simd_w == 0);
- if (!args_ok) return status::unimplemented;
-
- int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
-
- if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
- /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
- jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
- nstl::min(jcp.ow, num_avail_regs / 2));
- jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- /* check again ... */
- r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
- return status::unimplemented;
- }
- assert(jcp.nb_oc_blocking > 0);
- assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
-
- jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- jcp.nb_ic_blocking = 12;
- jcp.nb_ic_blocking_max = 16;
- } else {
- jcp.nb_ic_blocking = 1;
- jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
- }
-
- return status::success;
-}
-
-void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
-}
-
-void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
- int r_overflow)
-{
- int kw = jcp.kw;
- int kh = jcp.kh;
- int kd = jcp.kd;
- int iw = jcp.iw;
- int ih = jcp.ih;
- int id = jcp.id;
- int ow = jcp.ow;
-
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int nb_ic_block = jcp.nb_ic_blocking;
- int stride_w = jcp.stride_w;
- int stride_h = jcp.stride_h;
-
- Label kd_loop, skip_kd_loop;
- Label oc_loop, skip_oc_loop;
-
- for (int ii = 0; ii < nb_ic_block; ii++)
- for (int jj = 0; jj < ur_w; jj++) {
- uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
- Ymm(ur_w * ii + jj));
- }
-
- if (one_of(jcp.ndims, 3, 4)) {
- cmp(reg_channel_work, 0);
- jle(skip_oc_loop, T_NEAR);
- xor_(reg_channel, reg_channel);
-
- mov(aux_reg_ddst_oc_loop, reg_ddst);
- mov(aux_reg_kernel_oc_loop, reg_kernel);
-
- L(oc_loop);
- mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
- mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
- }
-
- if (jcp.ndims == 5) {
- assert(jcp.nb_oc_blocking == 1);
- push(oi_iter);
-
- mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_dst_d, reg_ddst);
- mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);
-
- L(kd_loop);
- mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
- } else {
- mov(kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_ddst, aux_reg_dst_d);
- mov(aux_reg_kernel, aux_reg_ker_d);
- }
-
- Label kh_loop, skip_kh_loop;
- cmp(kj, 0);
- jle(skip_kh_loop, T_NEAR);
- L(kh_loop); {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = get_iw_start(ki, l_overflow); // 0;
- int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
-
- for (int jj = jj_start ; jj < jj_end; jj += stride_w) {
- int aux_output_offset
- = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
- vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
- ptr[aux_reg_ddst
- + sizeof(float) * aux_output_offset]);
- }
-
- for (int ii = 0; ii < nb_ic_block; ii++) {
- int aux_kernel_offset
- = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
- + ki * jcp.ic_block * jcp.oc_block
- + ofm2 * jcp.ic_block;
- vmovups(ymm15,
- ptr[aux_reg_kernel
- + sizeof(float) * aux_kernel_offset]);
- for (int jj = jj_start; jj < jj_end; jj += stride_w)
- vfmadd231ps(Ymm(ur_w * ii + jj),
- Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
- }
- }
- }
- add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block
- * ic_block);
- sub(aux_reg_ddst, sizeof(float) * ow * oc_block);
-
- dec(kj);
- cmp(kj, 0);
- jg(kh_loop, T_NEAR);
- }
- L(skip_kh_loop);
-
- if (jcp.ndims == 5) {
- sub(aux_reg_dst_d,
- sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
- add(aux_reg_ker_d,
- sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_loop, T_NEAR);
- L(skip_kd_loop);
-
- pop(oi_iter);
- }
-
- if (one_of(jcp.ndims, 3, 4)) {
- int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
- * jcp.oc_block;
- int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
- * jcp.ic * jcp.oc_block;
-
- add(aux_reg_ddst_oc_loop, ddst_oc_shift);
- add(aux_reg_kernel_oc_loop, kernel_oc_shift);
-
- inc(reg_channel);
- cmp(reg_channel, reg_channel_work);
- jl(oc_loop, T_NEAR);
-
- L(skip_oc_loop);
- mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
- }
-
- Label no_update_label;
- cmp(reg_channel, 0);
- je(no_update_label, T_NEAR);
- for (int ii = 0; ii < nb_ic_block; ii++) {
- for (int jj = 0; jj < ur_w; jj++) {
- size_t offt =
- sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
- vmovups(Ymm(15),
- make_safe_addr(reg_dsrc, offt, reg_long_offt));
- vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
- Ymm(15));
-
- }
- }
- L(no_update_label);
-
- for (int ii = 0; ii < nb_ic_block; ii++)
- for (int jj = 0; jj < ur_w; jj++) {
- size_t offt =
- sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
- vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt),
- Ymm(ur_w * ii + jj));
- }
-}
-
-void jit_avx2_conv_bwd_data_kernel_f32::generate() {
- preamble();
-
- mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
- mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
- mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);
-
- int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
- int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;
-
- int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
- int r_overflow = nstl::max(0, (jcp.kw - 1
- - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
- int r_overflow1 = nstl::max(0, (jcp.kw - 1
- - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
-
- int n_oi = jcp.iw / jcp.ur_w;
- if (r_overflow1 > 0)
- n_oi--;
-
- if (jcp.ur_w == jcp.iw) {
- compute_loop(jcp.ur_w, l_overflow, r_overflow);
- } else if (n_oi == 0) {
- compute_loop(jcp.ur_w, l_overflow, r_overflow1);
- add(reg_dsrc, dsrc_shift);
- add(reg_ddst, ddst_shift);
- if (jcp.ur_w_tail != 0)
- compute_loop(jcp.ur_w_tail, 0, r_overflow);
- } else {
- xor_(oi_iter, oi_iter);
- if (l_overflow > 0) {
- compute_loop(jcp.ur_w, l_overflow, 0);
- add(reg_dsrc, dsrc_shift);
- add(reg_ddst, ddst_shift);
- inc(oi_iter);
- }
-
- if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
- Label ow_loop;
- L(ow_loop); {
- compute_loop(jcp.ur_w, 0, 0);
- add(reg_dsrc, dsrc_shift);
- add(reg_ddst, ddst_shift);
- inc(oi_iter);
- cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
- }
- }
-
-        if (r_overflow1 > 0) {
- compute_loop(jcp.ur_w, 0, r_overflow1);
- add(reg_dsrc, dsrc_shift);
- add(reg_ddst, ddst_shift);
- }
-
- if (jcp.ur_w_tail != 0)
- compute_loop(jcp.ur_w_tail, 0, r_overflow);
- }
-
- this->postamble();
-}
-
-status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d)
-{
- if (!mayiuse(avx2)) return status::unimplemented;
-
- const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
-
- int ndims = diff_src_d.ndims();
- jcp.ndims = ndims;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = diff_src_d.dims()[0];
-
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
-
- jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
- jcp.iw = diff_src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
- jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
- jcp.ow = diff_dst_d.dims()[ndims-1];
-
- jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
-
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
-
- const int simd_w = 8;
-
- /* derivatives */
- jcp.idp = jcp.id + 2 * jcp.f_pad;
- jcp.ihp = jcp.ih + 2 * jcp.t_pad;
- jcp.iwp = jcp.iw + 2 * jcp.l_pad;
- jcp.ohp = jcp.oh; /* do we really need */
- jcp.owp = jcp.ow; /* padded output ??? */
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1;
-
- /* gemm-based convolution performs better in these cases */
- if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
- return status::unimplemented;
-
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- jcp.oc_block = simd_w;
- if (jcp.oc % jcp.oc_block) return status::unimplemented;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- jcp.ur_h = 1; /* no code-unrolling by h so far */
- jcp.nb_ic_blocking = 1;
- jcp.nb_oc_blocking = 1;
- jcp.ur_w = 1;
-
-    if (one_of(ndims, 3, 4) && jcp.ow < 40)
- jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
-
- if (ndims == 3) {
- jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
- } else if (ndims == 4) {
- jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
- } else if (ndims == 5) {
- jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
- }
-
- bool args_ok = true
- && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
- && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i,
- gOIdhw8o8i, OIdhw8o8i)
- && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
- && jcp.stride_w == jcp.stride_h
- && jcp.stride_d == 1
- && jcp.dilate_d == 0
- && jcp.dilate_h == 0
- && jcp.dilate_w == 0
- && jcp.ic % simd_w == 0
- && jcp.oc % simd_w == 0
- && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1
- && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
- && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
- if (!args_ok) return status::unimplemented;
- jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
- int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
-
-    const int max_regs = 15; /* Maximum number of registers available for
- result accumulation and delta dst data.
- One additional register is reserved for weights
- data. */
-
- /* Find the best blocking with maximum number of fma instructions
- per ur_w * nb_ic_blocking compute loops. Number of required registers
- is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
- ur_w must be divisible by stride_w */
- if (jcp.stride_w + 1 > max_regs) /* Minimal possible registers
- distribution exceeds max_regs */
- return status::unimplemented;
-
- int best_nfmas = 0;
- for (int b = 1; b <= 4; b++)
- {
- if (jcp.nb_ic % b != 0)
- continue;
-
- for (int u = jcp.stride_w;
- u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
- u += jcp.stride_w)
- {
- int ur_w = nstl::min(u, jcp.iw);
- /* maximum 1 step with l_overflow so far */
- if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
- continue;
- int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
- if (nfmas > best_nfmas
- || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
- jcp.ur_w = ur_w;
- jcp.nb_ic_blocking = b;
- best_nfmas = nfmas;
- }
- }
- }
- if (best_nfmas == 0) /* can't find appropriate blocking */
- return status::unimplemented;
-
- jcp.ur_w_tail = jcp.iw % jcp.ur_w;
-
- int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
- - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
- /* maximum 1 ur_w block with r_overflow so far */
- if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
- return status::unimplemented;
-
- if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
- return status::unimplemented;
-
- return status::success;
-}
-
-void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- UNUSED(scratchpad);
- UNUSED(jcp);
-}
-
-void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
- this->preamble();
-
- mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- compute_oh_loop_common();
- this->postamble();
-}
-
-status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_weights_d,
- const memory_desc_wrapper &diff_dst_d) {
- if (!mayiuse(avx2)) return status::unimplemented;
-
- const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
- int ndims = src_d.ndims();
- jcp.ndims = ndims;
-
- jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
- jcp.iw = src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
- jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
- jcp.ow = diff_dst_d.dims()[ndims-1];
-
- jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
- jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
-
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
-
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
-
- if (ndims == 3) {
- jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(
- Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
- } else if (ndims == 4) {
- jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(
- Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
- } else if (ndims == 5) {
- jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(
- Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
- }
- jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
-
- const bool flat = jcp.ic == 3;
- const bool mimo = !flat;
-
- const int simd_w = 8;
-
- jcp.b_pad = nstl::max(
- 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
- jcp.r_pad = nstl::max(
- 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
-
- int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
- - jcp.f_pad);
- if (ndims == 5)
- if (jcp.f_pad != 0 || back_pad != 0)
- return status::unimplemented;
-
- const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1);
- const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1);
- const bool boundaries_ok = true
- && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad
- && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad;
- if (!boundaries_ok)
- return status::unimplemented;
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1;
-
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- if (mimo)
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- bool args_ok = true
- && IMPLICATION(flat, true
- && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
- && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
- gOdhwi8o))
- && IMPLICATION(mimo, true
- && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
- && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
- OIdhw8i8o, gOIdhw8i8o))
- && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
- && IMPLICATION(mimo, jcp.ic % simd_w == 0)
- && jcp.oc % simd_w == 0
- && jcp.kw < 14
- && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
- && jcp.kh <= jcp.ih /* [bwd_w:r2] */
- && jcp.kd <= jcp.f_pad + jcp.id
- && jcp.kd <= jcp.id
- && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
- && jcp.dilate_d == 0
- && jcp.dilate_h == 0
- && jcp.dilate_w == 0;
- if (!args_ok) return status::unimplemented;
-
- jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- jcp.oc_block = simd_w;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
-
- return status::success;
-}
-
-void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
-{
- Label kd_comeback_loop;
- mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
- L(kd_comeback_loop); {
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? 1 : jcp.ic_block;
- sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult);
- sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block
- * jcp.oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kd_comeback_loop, T_NEAR);
- }
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
-{
- mov(kj, reg_kh);
- Label kh_comeback_loop;
- L(kh_comeback_loop); {
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? 1 : jcp.ic_block;
- sub(reg_input, sizeof(float) * jcp.iw * inp_mult);
- sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_comeback_loop, T_NEAR);
- }
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
- int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
- int kernel_offset, int output_offset)
-{
- const int kw = jcp.kw;
- const int ic_block = jcp.ic_block;
- const int oc_block = jcp.oc_block;
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- size_t off
- = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
- + kernel_offset;
- vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]);
- }
-
- for (int i_ur = 0; i_ur < ur_w; i_ur++) {
- vmovups(Ymm(kw * ic_block_step + 0),
- yword[reg_output
- + sizeof(float) * i_ur * oc_block + output_offset]);
-
- for (int i_kw = 0; i_kw < kw; i_kw++) {
- int i_iw = i_ur * jcp.stride_w + i_kw;
- if (i_iw - pad_l < 0
- || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
- continue;
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- size_t i_off = (size_t)input_offset + sizeof(float)*(
- one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? (i_iw - pad_l) + i_ic
- * ((size_t)jcp.id * jcp.ih * jcp.iw)
- : (i_iw - pad_l) * ic_block + i_ic);
- vbroadcastss(Ymm(kw * ic_block_step + 1),
- make_safe_addr(reg_input, i_off, reg_long_offt));
- vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
- Ymm(kw * ic_block_step + 0),
- Ymm(kw * ic_block_step + 1));
- }
- }
- }
-
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- size_t off
- = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
- + kernel_offset;
- vmovups(yword[reg_kernel + off],
- Ymm(i_kw * ic_block_step + i_ic));
- }
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp()
-{
- int ic_block_step;
- if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
- ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
- } else {
- ic_block_step = jcp.kw > 7 ? 1
- : jcp.kw > 3 ? 2
- : jcp.kw > 1 ? 4 : 8;
- }
-
- const int max_ur_w = jcp.ow > 56 ? 14 : 28;
-
- if (jcp.ow <= max_ur_w)
- compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
- else
- compute_oh_step_common(ic_block_step, max_ur_w);
-
- if (jcp.ndims == 5) {
- od_step_comeback_pointers();
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- } else {
- oh_step_comeback_pointers();
- }
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
- int ic_block_step, int max_ur_w)
-{
- UNUSED(max_ur_w);
-
- const int ic_block = jcp.ic_block;
- const int oc_block = jcp.oc_block;
- int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
- Label kd_loop;
-
- const int r_pad
- = nstl::max(0,
- (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
-
- if (jcp.ndims == 5) {
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
- mov(ki, jcp.kd);
- L(kd_loop);
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- }
-
- mov(kj, reg_kh);
- Label kh_loop;
- L(kh_loop); {
- xor_(b_ic, b_ic);
- Label ic_block_loop;
- L(ic_block_loop); {
- compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0,
- 0, 0);
- size_t inp_icblk_stride = sizeof(float) * ic_block_step
- * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? jcp.id*jcp.ih*jcp.iw : 1);
- safe_add(reg_input, inp_icblk_stride, reg_long_offt);
- add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
- add(b_ic, ic_block_step);
- cmp(b_ic, ic_block);
- jl(ic_block_loop, T_NEAR);
- }
-        if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
- size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
- safe_sub(reg_input, offt, reg_long_offt);
- add(reg_input, sizeof(float) * jcp.iw);
- } else {
- add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
- }
- add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_loop, T_NEAR);
- }
-
- if (jcp.ndims == 5) {
- add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
- add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
- * oc_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_loop, T_NEAR);
- }
-
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
- int ic_block_step, int max_ur_w)
-{
- const int ic_block = jcp.ic_block;
- const int oc_block = jcp.oc_block;
- const int stride_w = jcp.stride_w;
- int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
- Label kd_loop;
-
- const int r_pad = jcp.r_pad;
-
- int ur_w = nstl::min(jcp.ow, max_ur_w);
- int ur_w_trips = jcp.ow / ur_w;
- int ur_w_tail = jcp.ow % ur_w;
- if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
- if (ur_w_trips > 1) {
- ur_w_tail += ur_w;
- ur_w_trips--;
- } else {
- ur_w_tail += (ur_w - ur_w / 2);
- ur_w = ur_w / 2;
- }
- }
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block;
-
- int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult;
- int output_comeback = ur_w_trips * ur_w * oc_block;
-
- if (jcp.ndims == 5) {
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
- mov(ki, jcp.kd);
- L(kd_loop);
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- }
-
- mov(kj, reg_kh);
- Label kh_loop;
- L(kh_loop); {
- xor_(b_ic, b_ic);
- Label ic_block_loop;
- L(ic_block_loop); {
- if (jcp.l_pad != 0) {
- ur_w_trips--;
- compute_ic_block_step(ur_w,
- jcp.l_pad, 0, ic_block_step, 0, 0, 0);
- add(reg_input, sizeof(float)
- * (ur_w * stride_w - jcp.l_pad) * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_block);
- }
-
- if (ur_w_trips > 0) {
- xor_(reg_ur_w_trips, reg_ur_w_trips);
- Label ow_block_loop;
- L(ow_block_loop); {
- compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
- add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_block);
-
- inc(reg_ur_w_trips);
- cmp(reg_ur_w_trips, ur_w_trips);
- jl(ow_block_loop, T_NEAR);
- }
- }
-
- if (ur_w_tail > 0)
- compute_ic_block_step(ur_w_tail,
- 0, r_pad, ic_block_step, 0, 0, 0);
-
- sub(reg_input, sizeof(float) * input_comeback);
- sub(reg_output, sizeof(float) * output_comeback);
-
- size_t inp_icblk_stride = sizeof(float) * ic_block_step
- * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? jcp.id*jcp.ih*jcp.iw : 1);
- safe_add(reg_input, inp_icblk_stride, reg_long_offt);
- add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
-
- add(b_ic, ic_block_step);
- cmp(b_ic, jcp.ic_block);
- jl(ic_block_loop, T_NEAR);
- }
- if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
- size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
- safe_sub(reg_input, offt, reg_long_offt);
- add(reg_input, sizeof(float) * jcp.iw);
- } else {
- add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
- }
- add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_loop, T_NEAR);
- }
-
- if (jcp.ndims == 5) {
- add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
- add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
- * oc_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_loop, T_NEAR);
- }
-
-}
-
-inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common()
-{
- const int icoc_block = jcp.ic_block * jcp.oc_block;
- const int t_pad = jcp.t_pad;
- const int stride_h = jcp.stride_h;
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
- ? 1 : jcp.ic_block;
- int b_pad = jcp.b_pad;
-
- Label oh_tpad_loop, oh_loop, oh_loop_end;
-
- mov(reg_kh, jcp.kh);
- xor_(reg_ih_count, reg_ih_count);
- xor_(reg_oj, reg_oj);
- if (t_pad > 0) {
- assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
- mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
- add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block);
-
- L(oh_tpad_loop); {
- compute_oh_step_disp();
- add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
- sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block);
-
- inc(reg_oj);
- add(reg_ih_count, stride_h);
- add(reg_kh, stride_h);
-
- /* the overlap between input and kernel may not reach kernel size.
- * so far we do not support that (until we put constant here) */
- const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
- cmp(reg_kh, final_inp_ker_overlap);
- jl(oh_tpad_loop, T_NEAR);
- }
-
- if (t_pad % stride_h != 0) {
- int inp_corr = stride_h - t_pad % stride_h;
- add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block);
- add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult);
- }
- }
- cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
- jge(oh_loop_end, T_NEAR);
- cmp(reg_oj, jcp.oh);
- jge(oh_loop, T_NEAR);
-
- mov(reg_kh, jcp.kh);
- L(oh_loop); {
- compute_oh_step_disp();
- add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
- add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
-
- inc(reg_oj);
- add(reg_ih_count, stride_h);
-
- cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
- jge(oh_loop_end, T_NEAR);
-
- cmp(reg_oj, jcp.oh);
- jl(oh_loop, T_NEAR);
- }
- L(oh_loop_end);
- if (b_pad > 0) {
- Label oh_bpad_loop, oh_bpad_loop_end;
- cmp(reg_oj, jcp.oh);
- jge(oh_bpad_loop_end, T_NEAR);
-
- mov(reg_kh, jcp.ih + t_pad);
- sub(reg_kh, reg_ih_count);
- L(oh_bpad_loop); {
- compute_oh_step_disp();
- add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
- add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
-
- sub(reg_kh, stride_h);
- cmp(reg_kh, 0);
- jle(oh_bpad_loop_end, T_NEAR);
-
- inc(reg_oj);
- cmp(reg_oj, jcp.oh);
- jl(oh_bpad_loop, T_NEAR);
- }
- L(oh_bpad_loop_end);
- }
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp
deleted file mode 100644
index 412c50c9ee..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp
+++ /dev/null
@@ -1,225 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP
-#define JIT_AVX2_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "cpu_memory.hpp"
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx2_conv_fwd_kernel_f32: public jit_generator {
- jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this,
- jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- ~jit_avx2_conv_fwd_kernel_f32() {
- delete eltwise_injector_;
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32)
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- reg64_t reg_input = rax;
- reg64_t aux_reg_input = r8;
- reg64_t reg_kernel = rdx;
- reg64_t aux_reg_kernel = r9;
- reg64_t reg_output = rsi;
- reg64_t reg_bias = rbx;
-
- reg64_t aux_reg_inp_d = r11;
- reg64_t aux_reg_ker_d = abi_not_param1;
-
- reg64_t reg_ki = rsi;
- reg64_t kj = r10;
- reg64_t oi_iter = r11;
- reg64_t ki_iter = r12;
- reg64_t reg_kh = abi_not_param1;
- reg64_t reg_oc_blocks = r14;
- reg64_t imm_addr64 = r15;
- reg64_t reg_long_offt = r15;
- Xbyak::Reg32 reg_ci_flag = r13d;
-
- Xbyak::Ymm ytmp = Xbyak::Ymm(14);
-
- jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_;
-
- inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
- int oc_blocks);
- inline void oh_step_nopad(int ur_w, int pad_l, int pad_r,
- char pad_label, int oc_blocks, char oc_blocks_label);
- inline void width_blk_step(int ur_w, int pad_l, int pad_r,
- char pad_label, int oc_blocks, char oc_blocks_label);
- inline void solve_common(int oc_blocks, char oc_blocks_label);
-
- void generate();
-};
-
-struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32)
-
- jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
- {
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
-
- reg64_t reg_ddst = rax;
- reg64_t aux_reg_ddst = r8;
- reg64_t reg_kernel = rdx;
- reg64_t aux_reg_kernel = r10;
- reg64_t reg_dsrc = rsi;
- reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only
- reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5
- case only */
-
- reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only
- reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only
-
- reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only
- reg64_t kj = r11;
- reg64_t oi_iter = r12;
- reg64_t reg_kh = r14;
- reg64_t reg_channel = r13; // used in ndims < 5 case only
- reg64_t reg_channel_work = r9; // used in ndims < 5 case only
- reg64_t reg_long_offt = r15;
-
- inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
-
- void generate();
-
- inline int get_iw_start(int ki, int l_overflow)
- {
- int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
- + l_overflow * jcp.stride_w
- - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
-
- return res;
- }
-
- inline int get_iw_end(int ur_w, int ki, int r_overflow)
- {
- if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
- ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
- int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
- + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
-
- return ur_w - res;
- }
-};
-
-struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32)
-
- jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
- {
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_weights_d,
- const memory_desc_wrapper &diff_dst_d);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- reg64_t reg_input = rax;
- reg64_t reg_kernel = rdx;
- reg64_t reg_output = rsi;
- reg64_t b_ic = abi_not_param1;
- reg64_t kj = r8;
- reg64_t reg_kh = r9;
- reg64_t reg_ur_w_trips = r10;
- reg64_t reg_tmp = r11;
- reg64_t reg_oj = r15;
- reg64_t reg_ih_count = rbx;
- reg64_t aux_reg_input = r12;
- reg64_t aux_reg_kernel = r13;
- reg64_t ki = r14;
- reg64_t reg_long_offt = r11;
-
- inline void od_step_comeback_pointers();
- inline void oh_step_comeback_pointers();
- inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
- int ic_block_step, int input_offset, int kernel_offset,
- int output_offset);
- inline void compute_oh_step_disp();
- inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
- inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
- inline void compute_oh_loop_common();
-
- void generate();
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
deleted file mode 100644
index 13f61e84fe..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
+++ /dev/null
@@ -1,410 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx2_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-#define src_blk_off(f, n, c, d, h, w) \
- (pd()->ndims() == 3) \
- ? (f).blk_off(n, c, w) \
- : (pd()->ndims() == 4) \
- ? (f).blk_off(n, c, h, w) \
- : (f).blk_off(n, c, d, h, w)
-
-#define wht_blk_off_(f, g, ...) \
- pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
-#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
- (pd()->ndims() == 3) \
- ? wht_blk_off_(f, g, oc, ic, kw) \
- : (pd()->ndims() == 4) \
- ? wht_blk_off_(f, g, oc, ic, kh, kw) \
- : wht_blk_off_(f, g, oc, ic, kd, kh, kw)
-
-void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const auto &jcp = kernel_->jcp;
-
- int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
- const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od
- * jcp.oh;
-
- auto ker = [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- int icbb = 0;
- while (icbb < jcp.nb_ic) {
- int icb_step = jcp.nb_ic_blocking;
- int icb_step_rem = jcp.nb_ic - icbb;
- if (icb_step_rem < jcp.nb_ic_blocking_max)
- icb_step = icb_step_rem;
-
- size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
- nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
- od, jcp.od, oh, jcp.oh);
- for (size_t iwork = start; iwork < end; ++iwork) {
- int ocb = ocbb * jcp.nb_oc_blocking;
- int ocb_num = jcp.nb_oc_blocking;
-
- for (int icb = icbb; icb < icbb + icb_step; ++icb) {
- auto par_conv = jit_conv_call_s();
-
- const int ij = oh * jcp.stride_h;
- const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
- const int i_b_overflow = nstl::max(jcp.ih, ij
- + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
-
- const int dj = od * jcp.stride_d;
- const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
- const int d_b_overflow = nstl::max(jcp.id, dj
- + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
-
- const size_t _oc = g * jcp.nb_oc + ocb;
- const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;
-
- const int ih = nstl::max(ij - jcp.t_pad
- + div_up(i_t_overflow,
- (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
-
- const int id = nstl::max(dj - jcp.f_pad
- + div_up(d_t_overflow,
- (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);
-
- par_conv.src = &src[src_blk_off(src_d, n,
- jcp.ic == 3 ? 0 : _ic, id, ih, 0)];
-
- par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];
-
- const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
- const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
- par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
- jcp.ic == 3 ? 0 : icb, wd, wh, 0)];
-
- if (icb == 0) {
- if (bias)
- par_conv.bias =
- &bias[bias_d.blk_off(_oc * jcp.oc_block)];
- par_conv.flags |= FLAG_IC_FIRST;
- }
-
- if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
- par_conv.flags |= FLAG_IC_LAST;
- }
-
- par_conv.oc_blocks =
- nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
-
- par_conv.kw_padding = 0;
- const int kh_padding = jcp.kh
- - div_up(i_t_overflow, (jcp.dilate_h + 1))
- - div_up(i_b_overflow, (jcp.dilate_h + 1));
- par_conv.kh_padding = nstl::max(0, kh_padding);
-
- const int kd_padding = jcp.kd
- - div_up(d_t_overflow, (jcp.dilate_d + 1))
- - div_up(d_b_overflow, (jcp.dilate_d + 1));
- par_conv.kd_padding = nstl::max(0, kd_padding);
-
- kernel_->jit_ker(&par_conv);
- }
- nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
- od, jcp.od, oh, jcp.oh);
- }
- icbb += icb_step;
- }
- };
-
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
- utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bias = padded_bias;
- }
-
- parallel(0, ker);
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
-
-void jit_avx2_convolution_bwd_data_t::execute_backward_data(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
- int ih_block_size = jcp.ih;
- int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
- size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;
- if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
- ih_block_size = 1;
- num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
- work_amount *= num_ih_blocks;
- }
-
- auto ker = [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- size_t n{0}, g{0}, icbb{0}, ihb{0};
- nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work,
- ihb, num_ih_blocks);
- for (size_t iwork = start; iwork < end; ++iwork) {
- for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
- for (int id = 0; id < jcp.id; ++id) {
- auto par_conv = jit_conv_call_s();
-
- const int idp = jcp.id + 2 * jcp.f_pad;
- const int d_t_overflow = nstl::max(0,
- jcp.kd - 1 - id - jcp.f_pad);
- const int back_pad = idp - jcp.id - jcp.f_pad;
- const int d_b_overflow = nstl::max(0,
- jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
- const int od = id + jcp.f_pad - d_b_overflow;
-
- int ih_start = ihb * ih_block_size;
- int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
- for (int ih = ih_start; ih < ih_end; ++ih) {
-
- const int i_t_overflow = nstl::max(0, (jcp.kh - 1
- - ih - jcp.t_pad) / jcp.stride_h);
- const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
- + ih - jcp.b_pad) / jcp.stride_h);
- int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
- + jcp.b_pad - ih) % jcp.stride_h);
- int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;
-
- par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
- par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
- / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
- par_conv.kw_padding = 0;
-
- const int k_lo = overflow_kh_lo
- + i_b_overflow * jcp.stride_h;
- const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
-
- par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
- /*jcp.ic == 3 ? 0 :*/
- g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
- par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
- n, g * jcp.nb_oc + oc, od, oh, 0)];
- par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
- jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
- d_b_overflow, k_lo, 0)];
-
- par_conv.src_prf = nullptr;
- par_conv.dst_prf = nullptr;
- par_conv.filt_prf = nullptr;
- par_conv.channel = oc;
- par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
- jcp.nb_oc_blocking);
-
- kernel_->jit_ker(&par_conv);
- }
- }
- nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
- num_ih_blocks);
- }
- };
-
- parallel(0, ker);
-}
-
-void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- auto scratchpad = this->scratchpad(ctx);
-
- data_t *diff_bias = pd()->wants_padded_bias()
- ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_bia);
- auto rb = this->reducer_bias_;
- rb->init(reducer_bia_scratchpad);
-
- auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_wei);
- auto rw = this->reducer_weights_;
- rw->init(reducer_wei_scratchpad);
-
- auto ker = [&](int ithr, int nthr) {
- assert(nthr == rw->balancer().nthr_);
-
- const int w_job_start = rw->balancer().ithr_job_off(ithr);
- const int w_njobs = rw->balancer().ithr_njobs(ithr);
-
- if (w_njobs == 0) return;
-
- /* reduction dimension */
- int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
- balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
- rw->balancer().id_in_group(ithr), img_od_start, img_od_end);
-
- int img_start = img_od_start, img_end = img_od_end;
- nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
- const int img_first = img;
-
- /* jobs */
- int g_start{0}, ocb_start{0}, icb_start{0};
- nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
- jcp.nb_oc, icb_start, jcp.nb_ic);
-
- while (img_start < img_end) {
- int g = g_start, ocb = ocb_start, icb = icb_start;
-
- const int work_rem = img_end - img_start;
- const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
- const int id_s = od_s * jcp.stride_d;
- const int idp = jcp.id + jcp.f_pad + jcp.back_pad;
-
- if (id_s < idp - jcp.back_pad - jcp.kd + 1)
- for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
- const size_t _oc = g * jcp.nb_oc + ocb;
- const size_t _ic = g * jcp.nb_ic + icb;
-
- /* TODO: put dw <-- 0 in kernel */
- if (img == img_first)
- array_set(rw->get_local_ptr(ithr, diff_weights,
- reducer_wei_scratchpad) +
- w_job_loc * rw->balancer().job_size_, 0,
- rw->balancer().job_size_);
-
- for (int od = od_s; od < od_e; ++od) {
- const int id = od * jcp.stride_d;
- if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;
-
- auto par_conv = jit_conv_call_s();
- par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
- par_conv.dst =
- &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
- par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
- reducer_wei_scratchpad) +
- w_job_loc * rw->balancer().job_size_;
-
- kernel_->jit_ker(&par_conv);
- }
- nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
- jcp.nb_ic);
- }
- nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
- }
- rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
- };
-
- auto ker_bias = [&](int ithr, int nthr) {
- assert(nthr == rb->balancer().nthr_);
-
- const int b_job_start = rb->balancer().ithr_job_off(ithr);
- const int b_njobs = rb->balancer().ithr_njobs(ithr);
-
- if (b_njobs == 0) return;
-
- /* reduction dimension */
- int img_start{0}, img_end{0};
- balance211(jcp.mb, rb->balancer().nthr_per_group_,
- rb->balancer().id_in_group(ithr), img_start, img_end);
-
- /* jobs */
- int g_start{0}, ocb_start{0};
- nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
- jcp.nb_oc);
-
- for (int img = img_start; img < img_end; ++img) {
- int g = g_start, ocb = ocb_start;
- for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
- const size_t _oc = g * jcp.nb_oc + ocb;
-
- const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
- reducer_bia_scratchpad) +
- b_job_loc * rb->balancer().job_size_;
-
- if (img == img_start)
- for (int o = 0; o < 8; ++o)
- d_bias[o] = 0.;
-
- for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
- PRAGMA_OMP_SIMD()
- for (int o = 0; o < 8; ++o)
- d_bias[o] += d_dst[o];
- d_dst += 8;
- }
-
- nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
- }
- }
- rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
- };
-
- parallel(0, [&](const int ithr, const int nthr) {
- ker(ithr, nthr);
- if (pd()->with_bias())
- ker_bias(ithr, nthr);
- });
-
- /* TODO: put this in ker_bias */
- if (pd()->wants_padded_bias()) {
- assert(jcp.ngroups == 1);
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- diff_bias_in[oc] = diff_bias[oc];
- }
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp
deleted file mode 100644
index bb65bce79c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp
+++ /dev/null
@@ -1,302 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP
-#define CPU_JIT_AVX2_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_reducer.hpp"
-
-#include "jit_avx2_conv_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx2_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
- jit_avx2_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_,
- *desc(), src_md(), weights_md(), dst_md(), *attr());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- const bool flat = IC() < 8;
- auto src_tag = flat
- ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
- : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto dst_tag =
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
- gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
- : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
- OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
-
- return set_default_formats_common(src_tag, wei_tag, dst_tag);
- }
- };
-
- jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); }
- ~jit_avx2_convolution_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_conv_fwd_kernel_f32 *kernel_;
-};
-
-struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
- jit_avx2_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(
- jcp_, *desc(), *diff_src_md(), *weights_md(),
- *diff_dst_md());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad,
- jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
- : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); }
- ~jit_avx2_convolution_bwd_data_t() { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_conv_bwd_data_kernel_f32 *kernel_;
-};
-
-struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
- jit_avx2_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
- jcp_, *desc(), *src_md(), *diff_weights_md(),
- *diff_dst_md());
- if (status != status::success) return status;
-
- init_balancers();
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad,
- jcp_);
-
- auto reducer_bia_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_bia);
- reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
-
- auto reducer_wei_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_wei);
- reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
- cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
- cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- const bool flat = IC() == 3;
-
- auto src_tag = flat
- ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
- : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto dst_tag =
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
- gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
- : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
- OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
-
- return set_default_formats_common(src_tag, wei_tag, dst_tag);
- }
-
- private:
- void init_balancers() {
- const int max_threads = mkldnn_get_max_threads();
- const size_t max_buffer_size = 1<<21; /* just a heuristic */
-
- if(with_bias()) {
- reducer_bia_conf_.init(reduce_balancer_t(max_threads,
- jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
- max_buffer_size));
- }
-
- reducer_wei_conf_.init(reduce_balancer_t(max_threads,
- jcp_.kd * jcp_.kh * jcp_.kw
- * jcp_.ic_block * jcp_.oc_block,
- jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc,
- jcp_.mb * jcp_.od, max_buffer_size));
- }
- };
-
- jit_avx2_convolution_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr)
- , reducer_weights_(nullptr)
- , reducer_bias_(nullptr)
- {
- kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_);
- reducer_bias_ =
- new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
- reducer_weights_ =
- new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_);
- }
-
- ~jit_avx2_convolution_bwd_weights_t() {
- delete kernel_;
- delete reducer_weights_;
- delete reducer_bias_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx2_conv_bwd_weights_kernel_f32 *kernel_;
- cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
deleted file mode 100644
index 635b83b2bf..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
+++ /dev/null
@@ -1,1255 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <float.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_memory.hpp"
-#include "cpu_barrier.hpp"
-
-#include "jit_uni_1x1_conv_utils.hpp"
-#include "jit_avx512_common_1x1_conv_kernel.hpp"
-
-#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk)
-{
- mov(aux1_reg_bcast_data, reg_bcast_data);
- mov(aux_reg_bcast_data, reg_bcast_data);
-
- mov(aux_reg_output_data, reg_output_data);
- mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
-
- if (jcp.ver == ver_4fma)
- {
- Label bcast_loop;
- Label bcast_loop_wraparound;
- Label bcast_loop_out;
- Label bcast_loop_ur_full;
-
- cmp(bcast_loop_iter, jcp.ur);
- jle(bcast_loop_wraparound, T_NEAR);
-
- L(bcast_loop); {
- assert(jcp.bcast_block % jcp.ur == 0);
- int num_substeps = jcp.bcast_block / jcp.ur;
- assert(num_substeps > 0 && num_substeps < 10);
- for (int i = 0; i < num_substeps; i++) {
- reduce_loop(load_loop_blk, jcp.ur, i, false);
- if (i < num_substeps - 1) {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_substep);
- }
- else {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
- - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_step
- - (num_substeps - 1) * jcp.bcast_loop_output_substep);
- }
- }
- sub(bcast_loop_iter, jcp.bcast_block);
- cmp(bcast_loop_iter, jcp.bcast_block);
- jg(bcast_loop, T_NEAR);
- }
-
- L(bcast_loop_wraparound);
- if (jcp.ur_tail) {
- je(bcast_loop_ur_full, T_NEAR);
- reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
- jmp(bcast_loop_out, T_NEAR);
- }
- L(bcast_loop_ur_full);
- reduce_loop(load_loop_blk, jcp.ur, 0, true);
- L(bcast_loop_out);
- }
- else
- {
- Label bcast_loop;
- Label bcast_loop_tail;
-
- cmp(bcast_loop_iter, jcp.ur);
- jl(bcast_loop_tail, T_NEAR);
-
- L(bcast_loop); {
- assert(jcp.bcast_block % jcp.ur == 0);
- int num_substeps = jcp.bcast_block / jcp.ur;
- assert(num_substeps > 0 && num_substeps < 10);
- for (int i = 0; i < num_substeps; i++) {
- reduce_loop(load_loop_blk, jcp.ur, i, false);
- if (i < num_substeps - 1) {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_substep);
- }
- else {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
- - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_step
- - (num_substeps - 1) * jcp.bcast_loop_output_substep);
- }
- }
- sub(bcast_loop_iter, jcp.bcast_block);
- cmp(bcast_loop_iter, jcp.bcast_block);
- jge(bcast_loop, T_NEAR);
- }
-
- L(bcast_loop_tail);
- if (jcp.ur_tail) {
- Label bcast_loop_tail_out;
- cmp(bcast_loop_iter, 0);
- jz(bcast_loop_tail_out, T_NEAR);
- reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
- L(bcast_loop_tail_out);
- }
- }
-}
-
-void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
- int ur, int substep, bool wraparound)
-{
- auto vreg_load = [=](int i_load, int i_fma) {
- return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
- + jcp.fma_step * i_load + i_fma);
- };
-
- auto vreg_accum = [=](int i_load, int i_ur) {
- return Zmm(i_ur * load_loop_blk + i_load);
- };
-
- auto bias_ptr = [=](int i_load) {
- return EVEX_compress_addr(reg_bias_data,
- jcp.typesize_out * jcp.oc_block * i_load);
- };
-
- auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
- assert(i_ur < jcp.ur);
- assert(i_reduce <= jcp.reduce_loop_unroll);
- int offt;
- if (one_of(jcp.prop_kind, forward_training, forward_inference,
- backward_data)) {
- assert(jcp.reduce_loop_unroll == jcp.reduce_block);
- offt = (i_reduce == jcp.reduce_loop_unroll)
- ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll
- : i_ur * jcp.reduce_loop_unroll + i_reduce;
- } else {
- if (jcp.transpose_src) {
- const int reduce_group = i_reduce / 4;
- const int reduce_shift = i_reduce % 4;
- offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift;
- }
- else
- offt = i_reduce * jcp.ic_block + i_ur;
- }
- return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
- bcast);
- };
-
- auto load_ptr = [=](int i_reduce, int i_load) {
- int offt;
- int u0 = i_reduce % jcp.reduce_loop_unroll;
- int u1 = i_reduce / jcp.reduce_loop_unroll;
- offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
- return EVEX_compress_addr(aux_reg_load_data,
- u1 * jcp.reduce_loop_load_step
- + jcp.typesize_in * offt);
- };
-
- auto output_ptr = [=](int i_load, int i_ur) {
- if (one_of(jcp.prop_kind, forward_training, forward_inference,
- backward_data))
- return EVEX_compress_addr(aux_reg_output_data,
- (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
- * jcp.typesize_out);
- else
- return ptr[aux_reg_output_data +
- (i_load
- ? reg_output_stride * i_load
- : 0) // TODO: Xbyak should allow 0 scale
- + jcp.typesize_out * jcp.load_block * i_ur];
- };
-
- auto init = [=]() {
- Label init_done;
- Label init_zero;
-
- if (jcp.with_sum) {
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- mic_prefetcht1(output_ptr(i_load, i_ur));
- }
- }
- }
-
- if (jcp.with_bias
- && one_of(jcp.prop_kind, forward_training, forward_inference)) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jz(init_zero, T_NEAR);
-
- for (int i_load = 0; i_load < load_loop_blk; i_load++)
- for (int i_ur = 0; i_ur < ur; ++i_ur)
- vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load));
- jmp(init_done, T_NEAR);
- }
-
- L(init_zero);
- for (int i_load = 0; i_load < load_loop_blk; ++i_load)
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- auto r = vreg_accum(i_load, i_ur);
- vpxord(r, r, r);
- }
- L(init_done);
- };
-
- auto store = [=]() {
- Label store_noadd;
- if (!jcp.with_sum) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jnz(store_noadd, T_NEAR);
- }
-
- for (int i_ur = 0; i_ur < ur; ++i_ur)
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- auto r = vreg_accum(i_load, i_ur);
- vaddps(r, r, output_ptr(i_load, i_ur));
- }
-
- L(store_noadd);
- if (jcp.with_eltwise) {
- Label store_noeltwise;
- test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
- jz(store_noeltwise, T_NEAR);
-
- eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
-
- L(store_noeltwise);
- }
-
- auto store_output = [=](bool output_is_aligned) {
- for (int i_ur = 0; i_ur < ur; ++i_ur)
- for (int i_load = 0; i_load < load_loop_blk; ++i_load)
- if (output_is_aligned && jcp.use_vmovntps)
- vmovntps(output_ptr(i_load, i_ur),
- vreg_accum(i_load, i_ur));
- else
- vmovups(output_ptr(i_load, i_ur),
- vreg_accum(i_load, i_ur));
- };
-
- Label unaligned_store, end_store;
- test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1);
- jnz(unaligned_store, T_NEAR);
- store_output(true);
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- store_output(false);
- }
- L(end_store);
- };
-
- auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load,
- bool last_block, bool wraparound, int reduce_step)
- {
- bool pf_ker_l1 = true;
- bool pf_ker_l2 = wraparound;
- int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk;
- int i_op = (i_reduce / reduce_step) * ur * load_loop_blk +
- i_ur * load_loop_blk + i_load;
-
- int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0;
- int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0;
- int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur;
-
- int pf_inp_ops = n_ops / 2; // # of operations during which to pf input
- int pf_inp_trigger;
- if (jcp.prop_kind == backward_weights)
- pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block);
- else
- pf_inp_trigger = nstl::max(1, pf_inp_ops / ur);
-
- int n_other_pf =
- load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1);
- int n_other_pf_ops = n_ops - pf_inp_ops;
- int other_pf_trigger
- = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0;
-
- if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) {
- // input prefetches have the highest priority b/c the
- // first iteration of the kernel block touches all the
- // cache lines
- int i_pf = i_op / pf_inp_trigger;
- auto pf_reg = wraparound && last_block
- ? reg_bcast_data
- : (last_block ? aux1_reg_bcast_data
- : aux_reg_bcast_data);
- int offt = i_pf;
- if (jcp.prop_kind == backward_weights) {
- offt += wraparound && last_block
- ? 0
- : (last_block ? jcp.is : jcp.reduce_block);
- offt *= jcp.bcast_block;
- } else {
- offt += wraparound && last_block
- ? 0
- : (last_block ? jcp.ur : jcp.bcast_dim);
- offt *= jcp.reduce_block;
- }
- mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
- } else if (i_op >= pf_inp_ops && n_other_pf) {
- // remaining prefetches are spread among the rest of the
- // operations; prefetches for output take priority
- // TODO: spread L2 prefetches among L1 prefetches
- i_op -= pf_inp_ops;
- if (i_op % other_pf_trigger == 0) {
- int i_pf = i_op / (load_loop_blk * other_pf_trigger);
- if (i_pf < n_pf_ker_l2) {
- int offt = (i_pf + (i_load + 1) * jcp.reduce_dim)
- * jcp.load_block;
- mic_prefetcht1(ptr[aux_reg_load_data
- + offt * jcp.typesize_in]);
- } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) {
- i_pf -= n_pf_ker_l2;
- auto pf_reg = last_block ? reg_load_data
- : aux_reg_load_data;
- int offt = (i_pf + i_load * jcp.reduce_dim
- + (last_block
- ? (wraparound ? jcp.reduce_dim : 0)
- : jcp.reduce_block))
- * jcp.load_block;
- mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
- } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) {
- i_pf -= n_pf_ker_l1 + n_pf_ker_l2;
- int offt = i_pf * jcp.load_block;
- mic_prefetcht0(ptr[aux_reg_output_data
- + offt * jcp.typesize_out]);
- }
- }
- }
- };
-
- auto fma_block = [=](bool last_block) {
- assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
-
- int reduce_step = jcp.fma_step;
-
- for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
- i_reduce += reduce_step) {
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
-                // if transposed input data is used and the spatial size is
-                // not divisible by the transpose step (4), then for the last
-                // reduce step we should load only the needed load_registers
-                // data and clear the remaining ones
- if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block
- && i_reduce == jcp.reduce_loop_unroll - reduce_step) {
- Label load_all;
- Label load_finish;
- test(reg_reduce_pos_flag, FLAG_SP_LAST);
- jz(load_all, T_NEAR);
-
- const int n_loads = jcp.is % jcp.fma_step;
- for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
- if (i_fma < n_loads)
- vmovups(vreg_load(i_load, i_fma),
- load_ptr(i_reduce + i_fma, i_load));
- else
- vpxord(vreg_load(i_load, i_fma),
- vreg_load(i_load, i_fma),
- vreg_load(i_load, i_fma));
- }
- jmp(load_finish);
-
- L(load_all);
- for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
- vmovups(vreg_load(i_load, i_fma),
- load_ptr(i_reduce + i_fma, i_load));
- }
- L(load_finish);
- } else {
- for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
- vmovups(vreg_load(i_load, i_fma),
- load_ptr(i_reduce + i_fma, i_load));
- }
- }
- }
-
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- if (jcp.ver == ver_avx512_core && jcp.expl_bcast
- && load_loop_blk > 1)
- vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- if (jcp.ver == ver_4fma)
- v4fmaddps(vreg_accum(i_load, i_ur),
- vreg_load(i_load, 0),
- bcast_ptr(i_reduce, i_ur, false));
- else if (jcp.ver == ver_avx512_core && jcp.expl_bcast
- && load_loop_blk > 1)
- vfmadd231ps(vreg_accum(i_load, i_ur),
- vreg_load(i_load, 0), vreg_bcast);
- else
- vfmadd231ps(vreg_accum(i_load, i_ur),
- vreg_load(i_load, 0),
- bcast_ptr(i_reduce, i_ur, true));
- prefetch_callback(ur, i_reduce, i_ur, i_load,
- last_block, wraparound, reduce_step);
- }
- }
- }
- };
- Label reduce_loop;
- Label reduce_loop_tail;
-
- mov(aux_reg_load_data, reg_load_data);
-
- mov(aux_reg_bcast_data, aux1_reg_bcast_data);
- init();
-
- mov(reduce_loop_iter, reg_reduce_loop_work);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jle(reduce_loop_tail, T_NEAR);
-
- L(reduce_loop); {
- fma_block(false);
- add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jg(reduce_loop, T_NEAR);
- }
-
- L(reduce_loop_tail);
- fma_block(true);
-
- store();
-}
-
-void jit_avx512_common_1x1_conv_kernel::generate()
-{
- preamble();
-
- mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
- mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
- mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
-
- sub(rsp, stack_space_needed);
-
- if (jcp.with_bias)
- mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
-
- mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
- mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
- mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
- mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
- mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
- if (one_of(jcp.prop_kind, forward_training, forward_inference))
- mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
- if (jcp.prop_kind == backward_weights)
- mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
-
- auto load_loop_body = [=](int load_loop_blk) {
- bcast_loop(load_loop_blk);
- add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
- switch (jcp.prop_kind) {
- case forward_training:
- case forward_inference:
- add(reg_bias_data,
- load_loop_blk * jcp.load_block * jcp.typesize_out);
- add(reg_output_data,
- load_loop_blk * jcp.bcast_dim * jcp.load_block *
- jcp.typesize_out);
- break;
- case backward_data:
- add(reg_output_data,
- load_loop_blk * jcp.bcast_dim * jcp.load_block *
- jcp.typesize_out);
- break;
- case backward_weights:
- for (int i_load = 0; i_load < load_loop_blk; i_load++)
- add(reg_output_data, reg_output_stride);
- break;
- default:
- assert(!"invalid prop_kind");
- }
- sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- };
-
- const int simd_w = 16;
-
- Label load_loop_blk[7];
-
- static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 };
- static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
- static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 };
-
- const int size_ur_cases_fma
- = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
- sizeof(ur_cases_fma_expl_bcast) :
- sizeof(ur_cases_fma_embd_bcast);
- const int size_ur_cases_4fma = sizeof(ur_cases_4fma);
-
- const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
- ur_cases_fma_expl_bcast :
- ur_cases_fma_embd_bcast;
- const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma;
- const int num_ur_cases =
- (jcp.ver == ver_4fma ? size_ur_cases_4fma : size_ur_cases_fma)
- / sizeof(*ur_cases);
-
- for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
- int label_idx = num_ur_cases - ur_idx - 1;
- if (jcp.ur <= ur_cases[ur_idx]) {
- cmp(reg_load_loop_work, simd_w * (label_idx + 1));
- jle(load_loop_blk[label_idx], T_NEAR);
- }
- }
-
- for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
- if (jcp.ur <= ur_cases[ur_idx]) {
- int label_idx = num_ur_cases - ur_idx - 1;
- L(load_loop_blk[label_idx]);
- {
- if (label_idx == 0) {
- cmp(reg_load_loop_work, 0);
- je(load_loop_blk[num_ur_cases], T_NEAR);
- }
- load_loop_body(label_idx + 1);
- if (label_idx - 1 > 0) {
- cmp(reg_load_loop_work, 2 * label_idx * simd_w);
- je(load_loop_blk[label_idx - 1], T_NEAR);
- }
- cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
- jge(load_loop_blk[label_idx]);
- }
- for (int idx = label_idx - 1; idx > 0; --idx) {
- cmp(reg_load_loop_work, simd_w * (idx + 1));
- je(load_loop_blk[idx], T_NEAR);
- }
- if (ur_idx < num_ur_cases - 2) {
- cmp(reg_load_loop_work, simd_w);
- jle(load_loop_blk[0], T_NEAR);
- }
- }
- }
- L(load_loop_blk[num_ur_cases]);
-
- add(rsp, stack_space_needed);
-
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
- jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr, int nthreads, bool reduce_src) {
- if (!mayiuse(avx512_common)) return status::unimplemented;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
- const int ndims = src_d.ndims();
-
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1
- && src_d.data_type() == data_type::f32;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
- jcp.ow = dst_d.dims()[ndims - 1];
-
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
- jcp.l_pad = cd.padding[0][ndims - 3];
-
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
- jcp.stride_w = cd.strides[ndims - 3];
-
- jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
- format_kind::undef, cd.diff_bias_desc.format_kind)
- != format_kind::undef;
-
- jcp.os = jcp.oh * jcp.ow;
- jcp.is = jcp.ih * jcp.iw;
- jcp.tr_is = rnd_up(jcp.is, 4);
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise) {
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
- if (dst_d.data_type() == data_type::s32) return status::unimplemented;
- }
-
- auto dat_tag = pick(ndims - 3, nCw16c, nChw16c);
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.ngroups == 1
- && jcp.src_tag == dat_tag
- && jcp.dst_tag == dat_tag;
- if (!args_ok) return status::unimplemented;
-
- args_ok = true
- && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
- && jcp.t_pad == 0 && jcp.l_pad == 0
- && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
- && jcp.kh == 1 && jcp.kw == 1;
- if (!args_ok) return status::unimplemented;
-
- jcp.ic_block = jcp.oc_block = simd_w;
- jcp.transpose_src = false;
-
- if (everyone_is(data_type::f32, src_d.data_type(),
- weights_d.data_type(), dst_d.data_type()))
- {
- const int is_bwd_d = jcp.prop_kind == backward_data;
- format_tag_t wei_tag = with_groups
- ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
- gOIhw16i16o, gIOhw16o16i)
- : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
- OIhw16i16o, IOhw16o16i);
-
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
-
- if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) &&
- ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4
- == 0) {
- jcp.ver = ver_4fma;
- jcp.fma_step = 4;
- } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops)
- && !reduce_src
- /* Heuristic condition for relation of src size to oc. Otherwise
-               the src transposition overhead exceeds the benefit from 4fma
- */
- && ((jcp.is * jcp.ic) / jcp.oc <= 2048)
- && mkldnn_thr_syncable()
- )
- {
- jcp.transpose_src = true;
- jcp.ver = ver_4fma;
- jcp.fma_step = 4;
- } else {
- jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma;
- jcp.fma_step = 1;
- }
- jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
- jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
- } else {
- return status::unimplemented;
- }
-
- /* once all the formats are set, check the padding consistency */
- args_ok = true
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
- if (!args_ok) return status::unimplemented;
-
- const int SMALL_SPATIAL = 10;
- const int BIG_SPATIAL = 28;
- const int BIG_REDUCE_DIM = 1024;
- const int BIG_LOAD_DIM = 256;
-
- int load_blocking{ 0 };
- int load_blocking_max{ 0 };
- int bcast_blocking{ 0 };
- int bcast_blocking_max{ 0 };
- int reduce_blocking{ 0 };
- int reduce_blocking_max{ 0 };
-
- jcp.load_grp_count = 1;
-
- const int L1_capacity = get_cache_size(1, true) / sizeof(float);
- const int L2_size = get_cache_size(2, true) / sizeof(float);
- const int L2_capacity = (L2_size * 3) / 4;
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference,
- backward_data)) {
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- jcp.reduce_dim = jcp.ic;
- jcp.reduce_block = jcp.ic_block;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.is;
- } else {
- jcp.reduce_dim = jcp.oc;
- jcp.reduce_block = jcp.oc_block;
-
- jcp.load_dim = jcp.ic;
- jcp.load_block = jcp.ic_block;
-
- jcp.bcast_dim = jcp.os;
- }
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
-
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
- jcp.load_loop_load_step
- = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
-
- // adjusting registry blocking
- int max_regs, min_regs, size_treshold, ur_step;
- const int spatial
- = (one_of(jcp.prop_kind, forward_training, forward_inference)) ?
- jcp.oh :
- jcp.ih;
- if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) {
- max_regs = 9;
- min_regs = 6;
- size_treshold = 14;
- ur_step = 1;
- jcp.expl_bcast = true;
-
- if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
- && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
- max_regs = 6;
- min_regs = 5;
- }
- } else {
- max_regs = jcp.ver == ver_4fma ? 28 : 30;
- min_regs = 9;
- size_treshold = jcp.ver == ver_4fma ? 28 : 14;
- ur_step = jcp.ver == ver_4fma ? 4 : 1;
- jcp.expl_bcast = false;
- jcp.use_vmovntps = true;
- }
- jcp.ur = 1;
- for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
- if ((spatial >= size_treshold && spatial % ur_w == 0)
- || (spatial < size_treshold && jcp.os % ur_w == 0)) {
- jcp.ur = ur_w;
- break;
- }
- }
- if (jcp.ur == 1) {
- jcp.ur = nstl::min(max_regs, jcp.os);
- int os_tail = jcp.os % max_regs;
- for (int i = max_regs; i >= min_regs; i -= ur_step) {
- int i_tail = jcp.os % i;
- if (i_tail > os_tail || i_tail == 0) {
- jcp.ur = i;
- os_tail = i_tail;
- if (i_tail == 0)
- break;
- }
- }
- }
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
-
- jcp.bcast_block = jcp.ur;
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out;
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in;
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_iter_step = jcp.load_block;
-
- if (jcp.prop_kind == backward_data)
- jcp.loop_order = loop_lbr;
- else
- jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
-
- int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
- int nb_load = div_up(jcp.load_dim, jcp.load_block);
-
- if (jcp.ver == ver_avx512_core && jcp.expl_bcast) {
- if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
- && spatial < BIG_SPATIAL)
- reduce_blocking = nstl::min(jcp.reduce_dim, 80);
- else if (spatial > SMALL_SPATIAL)
- reduce_blocking = nstl::min(jcp.reduce_dim, 512);
- else
- reduce_blocking = nstl::min(jcp.reduce_dim, 256);
-
- if ((jcp.mb > 28 && spatial >= 28)
- || (jcp.mb > 112 && spatial >= 17))
- jcp.use_vmovntps = true;
- else
- jcp.use_vmovntps = false;
- } else {
-
- reduce_blocking = nb_reduce;
- if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
- reduce_blocking = 16;
- else if (spatial > SMALL_SPATIAL
- && jcp.reduce_dim >= BIG_REDUCE_DIM)
- reduce_blocking = 8;
- reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
- reduce_blocking *= jcp.reduce_block;
- }
-
- // Check input data cache aliasing.
-    // For other ISAs these constants may need updating.
-    // 64 * 1024 is chosen due to the 1MB 16-way L2 cache.
-    // 7 is an empirical value, about half of 16,
-    // so roughly half of the set is left for other data - weights, dst
- int way_size = (64 * 1024) / jcp.typesize_in;
- int max_hits = 7;
- if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
- int nrb = reduce_blocking / simd_w;
- int sp = jcp.bcast_dim;
- int wl = way_size / simd_w;
- for (int start_off = 0; start_off < jcp.ur; start_off++) {
- for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
- if (off % sp >= jcp.ur || ++hits < max_hits)
- continue;
- int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
- reduce_blocking
- = nstl::min(reduce_blocking, max_r_blocking);
- break;
- }
- }
- }
-
- if (reduce_blocking < jcp.reduce_dim) {
- jcp.use_vmovntps = false;
- if (jcp.prop_kind == backward_data)
- jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
- else
- jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
- }
- load_blocking = jcp.load_dim;
-
- int load_size = jcp.load_dim * jcp.reduce_dim;
- int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
-
- if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads
- && nb_load * nb_bcast > nthreads) {
- // Some heuristic here
- float calc_koef = 0.01, best_cost = FLT_MAX;
- int n_lgc = nthreads;
- float ratio = (float)load_size / (float)bcast_size;
- int best_lgc = ratio > 1 ? n_lgc : 1;
- auto calc_job_cost = [&](int lb, int tg, float mem_k) {
- int bb_size = jcp.mb * div_up(nb_bcast, tg);
- float calc_size = (float)(bb_size * jcp.ur)
- * (lb * jcp.load_block) * jcp.reduce_dim;
- float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
- * jcp.reduce_dim;
- return calc_koef * calc_size + mem_k * mem_size;
- };
- for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
- lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
- int min_lb = nb_load / lgc;
- int max_lb = div_up(nb_load, lgc);
- int min_tg = nthreads / lgc;
- int max_tg = div_up(nthreads, lgc);
- // Some heuristic here
- float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
- float job_cost = 0.;
- if (nthreads % lgc < nb_load % lgc) {
- job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
- } else {
- auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
- auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
- job_cost = nstl::max(job_cost1, job_cost2);
- }
-
- if (job_cost < best_cost) {
- best_lgc = lgc;
- best_cost = job_cost;
- }
- }
- jcp.load_grp_count = best_lgc;
- load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
- } else {
- jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
- jcp.load_grp_count = best_divider(
- nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
- }
-
- if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64
- && load_size >= L2_size) {
- jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
- } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
- && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
- jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
- load_blocking = jcp.load_block;
- }
-
- if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
- && jcp.oh * jcp.ow > 64
- && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
- /* Looking for best loading dimension blocking
- * to get the best thread and data read/write efficiency
- * by finding the optimal 'load_chunk' value
- * Example:
- * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512
- * the 'best' load_chunk value should be 1
- * TODO: remove heuristic constants in above condition
- * TODO: check this blocking for other ISA
- */
- float best_eff = -1.f;
- int best_lgc = 1;
-
- for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
- int lgc = div_up(nb_load, load_chunk);
- if (lgc > nthreads)
- continue;
- int thr_per_grp = div_up(nthreads, lgc);
- int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
- * jcp.bcast_block;
- int load_per_thr = load_chunk * simd_w;
- float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
- float data_eff = (bcast_per_thr * load_per_thr)
- / (data_norm * data_norm);
- float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
- / div_up(nthreads, lgc);
- float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
- / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
- float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
- float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
- float overall_eff = data_eff + thr_eff + load_eff;
- if (overall_eff > best_eff) {
- best_eff = overall_eff;
- best_lgc = lgc;
- }
- }
- jcp.load_grp_count = best_lgc;
- load_blocking
- = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
- }
- bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
- div_up(nthreads, jcp.load_grp_count))
- * jcp.bcast_block;
- bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
- bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
-
- int space_for_bcast
- = (L2_capacity - /* kernel_size - */
- 2 * jcp.load_block * reduce_blocking
- - jcp.ur * reduce_blocking - 3 * 1024);
- if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
- space_for_bcast /= 2;
-
- int bcast_in_cache
- = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
- bcast_blocking = nstl::min(
- bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
-
- load_blocking_max = load_blocking;
- bcast_blocking_max = bcast_blocking * 3 / 2;
- reduce_blocking_max = reduce_blocking;
-
- } else if (jcp.prop_kind == backward_weights) {
-
- jcp.use_vmovntps = false;
- if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma)
- jcp.use_vmovntps = true;
-
- if (jcp.transpose_src)
- jcp.reduce_dim = jcp.tr_is;
- else
- jcp.reduce_dim = jcp.is;
-
- if (jcp.ver == ver_4fma) {
-            // reduce_block should be divisible by fma_step
- jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4);
- } else {
- jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
- if (jcp.reduce_dim % jcp.reduce_block != 0)
- jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
- if (jcp.reduce_block > 256) {
- jcp.reduce_block = 1;
- }
-
- }
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.ic;
- jcp.bcast_block = jcp.ic_block;
-
- if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) {
-            // if reduce_block is big, the generated JIT code may get large
-            // for small values of ur because reduce_loop_unroll = reduce_block
- jcp.ur = jcp.bcast_block / 2;
- jcp.expl_bcast = true;
- } else {
- jcp.ur = jcp.bcast_block;
- jcp.expl_bcast = false;
- }
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in;
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in;
-
- jcp.bcast_loop_output_step =
- jcp.oc_block * jcp.ic_block * jcp.typesize_out;
- jcp.bcast_loop_output_substep =
- jcp.oc_block * jcp.ur * jcp.typesize_out;
- jcp.bcast_loop_bcast_step =
- jcp.ic_block * jcp.reduce_dim * jcp.typesize_in;
- jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;
-
- jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in;
- jcp.load_loop_iter_step = jcp.oc_block;
-
- /* --- */
- balance(jcp, nthreads);
-
- load_blocking = div_up(jcp.load_dim, jcp.load_block);
- load_blocking = best_divider(load_blocking, 16, load_blocking, false);
- load_blocking *= jcp.load_block;
-
- load_blocking_max = load_blocking;
- assert(jcp.load_dim % load_blocking == 0);
-
- int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- int min_bcast_blocking = 5;
-
- bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- bcast_blocking = best_divider(
- bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
- bcast_blocking *= jcp.bcast_block;
- bcast_blocking_max = bcast_blocking;
- assert(jcp.bcast_dim % bcast_blocking == 0);
-
- // for reduction balance
- if (jcp.ver == ver_avx512_core) {
- int max_reduce_blocking
- = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
- int min_reduce_blocking = nstl::min(
- L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
- reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
- max_reduce_blocking, true);
- reduce_blocking
- = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
- jcp.reduce_block);
- } else {
- int max_reduce_blocking = L2_capacity
- / ((bcast_blocking + load_blocking) * jcp.reduce_block);
- max_reduce_blocking = nstl::min(max_reduce_blocking,
- (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block);
-
- int num_jobs = div_up(jcp.load_dim, load_blocking)
- * div_up(jcp.bcast_dim, bcast_blocking);
- int threads_per_job = nstl::max(1, nthreads / num_jobs);
- reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block);
- reduce_blocking = div_up(reduce_blocking, threads_per_job);
-
- reduce_blocking = best_divider(reduce_blocking,
- max_reduce_blocking - 2, max_reduce_blocking, true);
- reduce_blocking *= jcp.reduce_block;
- }
-
- reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
- } else
- return status::unimplemented;
-
- assert(load_blocking);
- assert(load_blocking_max);
- assert(bcast_blocking);
- assert(bcast_blocking_max);
- assert(reduce_blocking);
- assert(reduce_blocking_max);
- assert(load_blocking % jcp.load_block == 0);
- assert(reduce_blocking % jcp.reduce_block == 0);
- assert(load_blocking_max % jcp.load_block == 0);
- assert(reduce_blocking_max % jcp.reduce_block == 0);
- if (jcp.ver == ver_4fma) {
- assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
- assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
- }
-
- assert(jcp.bcast_block % jcp.ur == 0);
- assert(jcp.reduce_dim % jcp.reduce_block == 0);
-
- jcp.ur_tail = jcp.bcast_dim % jcp.ur;
-
- jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
- jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
- jcp.nb_load_blocking = load_blocking / jcp.load_block;
- jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
- jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
- jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
-
- jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
- jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- return status::success;
-}
-
-void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp) {
- using namespace mkldnn::impl::memory_tracking::names;
-
- if (jcp.prop_kind != backward_data && jcp.with_bias
- && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
-
- if (jcp.prop_kind == backward_weights) {
- const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
- scratchpad.book(key_conv_wei_reduction,
- jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
- }
-
- if (jcp.transpose_src) {
- const size_t tr_src_size =
- (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
- scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
- scratchpad.book(key_conv_tr_src_bctx,
- sizeof(simple_barrier::ctx_t) * jcp.nthr);
- }
-}
-
-void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
- int nthreads)
-{
- // initialize jcp reduction threading properties
- jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
- if (nthreads < jcp.ngroups) {
- /* simplification... fortunately it doesn't hurt much */
- return;
- }
- const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- const int nb_load = div_up(jcp.load_dim, jcp.load_block);
- const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- jcp.nthr_g = jcp.ngroups;
- const int nthr = nthreads / jcp.nthr_g;
-
- auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
- /* calculate the per-thread memory cost (reads/writes); the high-level
- * optimizer tries to minimize memory consumption. A few notes: (n1)
- * it is unclear why, but this essentially helps the first convolution...
- * (n2) assuming the reduction over the minibatch is always there:
- * - instead of 8 the coefficient should be 5 here (a write ~= 2 reads):
- *   kernel: temporal workspace, 1 write
- *   reduction: 1 read from the workspace and 1 write to diff_wei
- * - but experiments showed that 8 works better than 5 or 6... */
- int bcast_koeff = 1;
- int load_koeff = 1;
- int output_koeff = 12;
- if (jcp.transpose_src) {
- bcast_koeff = 5;
- load_koeff = 1;
- output_koeff = 8;
- }
- return 0
- + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
- * div_up(jcp.ngroups, jcp.nthr_g)
- * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block
- / jcp.stride_h / jcp.stride_w /* (n1) */
- + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
- * div_up(jcp.ngroups, jcp.nthr_g)
- * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block
- + (size_t)output_koeff /* (n2) */
- * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
- * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block
- * jcp.oc_block;
- };
-
- int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
- auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-
- /* step 1: find the best thread distribution with lowest memory cost */
- const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
- for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
- const int nthr_par = nthr / nthr_mb;
- const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
- for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
- nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
- auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
- if (mem_cost <= best_mem_cost) {
- best_mem_cost = mem_cost;
- jcp.nthr_mb = nthr_mb;
- jcp.nthr_oc_b = nthr_oc_b;
- jcp.nthr_ic_b = nthr_ic_b;
- }
- }
-
- if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
- }
- if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
- jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
-
- jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
- assert(jcp.nthr <= nthreads);
-}
-
-}
-}
-}
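The balance() routine removed above brute-forces every feasible decomposition of the thread count across the minibatch/reduction, output-channel and input-channel dimensions, keeping the split with the lowest estimated per-thread memory traffic. The standalone sketch below illustrates only that search pattern; cost_model, the work sizes and the weighting are illustrative placeholders, not the removed source's actual cost function.

#include <algorithm>
#include <cstdint>

// Illustrative cost model: approximate bytes touched per thread for a split.
// The removed code weights bcast/load/output traffic with tuned coefficients.
static std::uint64_t cost_model(int nthr_mb, int nthr_oc, int nthr_ic,
        int mb_work, int oc_work, int ic_work) {
    auto div_up = [](int a, int b) { return (a + b - 1) / b; };
    return (std::uint64_t)div_up(mb_work, nthr_mb) * div_up(ic_work, nthr_ic)
            + (std::uint64_t)div_up(mb_work, nthr_mb) * div_up(oc_work, nthr_oc)
            + (std::uint64_t)div_up(oc_work, nthr_oc) * div_up(ic_work, nthr_ic);
}

// Exhaustive search over thread decompositions, as balance() does.
static void pick_split(int nthreads, int mb_work, int oc_work, int ic_work,
        int &nthr_mb_out, int &nthr_oc_out, int &nthr_ic_out) {
    nthr_mb_out = nthr_oc_out = nthr_ic_out = 1;
    std::uint64_t best = cost_model(1, 1, 1, mb_work, oc_work, ic_work);
    for (int nthr_mb = 1; nthr_mb <= std::min(nthreads, mb_work); ++nthr_mb) {
        const int rest = nthreads / nthr_mb;
        for (int nthr_oc = 1; nthr_oc <= std::min(rest, oc_work); ++nthr_oc) {
            const int nthr_ic = std::max(1, std::min(rest / nthr_oc, ic_work));
            const std::uint64_t c = cost_model(nthr_mb, nthr_oc, nthr_ic,
                    mb_work, oc_work, ic_work);
            if (c <= best) {
                best = c;
                nthr_mb_out = nthr_mb;
                nthr_oc_out = nthr_oc;
                nthr_ic_out = nthr_ic;
            }
        }
    }
}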
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp
deleted file mode 100644
index d2ae017943..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp
+++ /dev/null
@@ -1,108 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
-#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
- jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
- }
-
- ~jit_avx512_common_1x1_conv_kernel() {
- delete eltwise_injector_;
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel)
-
- static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr,
- int nthreads, bool reduce_src);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp);
-
- jit_1x1_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_1x1_conv_call_s *);
-
- private:
- using reg64_t = const Xbyak::Reg64;
- using zmm_t = const Xbyak::Zmm;
-
- reg64_t reg_bcast_data = r8;
- reg64_t reg_load_data = r10;
- reg64_t reg_output_data = r9;
- reg64_t aux_reg_bcast_data = r14;
- reg64_t aux1_reg_bcast_data = rbx;
- reg64_t aux_reg_load_data = r15;
- reg64_t imm_addr64 = aux_reg_load_data;
- reg64_t aux_reg_output_data = abi_not_param1;
- reg64_t reg_load_loop_work = rsi;
- reg64_t reg_reduce_loop_work = r11;
- reg64_t bcast_loop_iter = rdx;
- reg64_t reduce_loop_iter = abi_param1;
- reg64_t reg_reduce_pos_flag = rax;
- reg64_t reg_output_stride = r13;
- reg64_t reg_bias_data = r12;
- reg64_t reg_relu_ns = r13;
- reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
-
- Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
-
- jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
-
- int bcast_loop_work_offt = 0;
- int stack_space_needed = 16;
-
- void bcast_loop(int load_loop_blk);
- void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
-
- void generate();
- static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
-};
-
-}
-}
-}
-
-#endif
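The header removed above follows the usual MKL-DNN JIT pattern: the kernel class emits machine code in its constructor (via Xbyak) and publishes it through jit_ker, a plain function pointer taking a POD call-parameter struct, which the driver invokes once per blocked work item. Below is a minimal, library-free sketch of that calling convention; the struct fields and reference_kernel are invented stand-ins, not the removed API.

#include <cstddef>

// POD call parameters, playing the role of jit_1x1_conv_call_s.
struct conv_call_params {
    const float *bcast_data;
    const float *load_data;
    float *output_data;
    std::size_t reduce_dim;
};

// Stand-in for the generated code; the real class emits AVX-512 via Xbyak
// and points jit_ker at the freshly emitted buffer instead.
static void reference_kernel(conv_call_params *p) {
    for (std::size_t i = 0; i < p->reduce_dim; ++i)
        p->output_data[0] += p->bcast_data[i] * p->load_data[i];
}

struct toy_conv_kernel {
    // The driver only ever calls through this function pointer.
    void (*jit_ker)(conv_call_params *);
    toy_conv_kernel() : jit_ker(&reference_kernel) {}
};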
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
deleted file mode 100644
index 54d58c8a39..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
+++ /dev/null
@@ -1,816 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-#include "jit_avx512_common_1x1_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-#define data_blk_off(f, n, c, h, w) \
- ((ndims == 3) \
- ? (f).blk_off(n, c, w) \
- : (f).blk_off(n, c, h, w))
-
-
-namespace {
-template <typename T, typename U>
-void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
- T nx, T &nx_start, T &nx_end, T nx_divider)
-{
- const int grp_count = nstl::min(nx_divider, nthr);
- const int grp_size_big = nthr / grp_count + 1;
- const int grp_size_small = nthr / grp_count;
- const int n_grp_big = nthr % grp_count;
- const int threads_in_big_groups = n_grp_big * grp_size_big;
-
- const int ithr_bound_distance = ithr - threads_in_big_groups;
- T grp, grp_ithr, grp_nthr;
- if (ithr_bound_distance < 0) { // ithr in first groups
- grp = ithr / grp_size_big;
- grp_ithr = ithr % grp_size_big;
- grp_nthr = grp_size_big;
- } else { // ithr in last groups
- grp = n_grp_big + ithr_bound_distance / grp_size_small;
- grp_ithr = ithr_bound_distance % grp_size_small;
- grp_nthr = grp_size_small;
- }
-
- balance211(nx, grp_count, grp, nx_start, nx_end);
- balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
-}
-}
-/* convolution forward */
-
-template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
-void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
-execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- auto scratchpad = this->scratchpad(ctx);
-
- const auto &jcp = kernel_->jcp;
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad.template get<dst_data_t>(
- key_conv_padded_bias);
- utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bias = padded_bias;
- }
-
- parallel(0, [&](const int ithr, const int nthr) {
- execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
- });
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
-
-template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
-void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
-execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
- const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
-
- const int ndims = src_d.ndims();
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- auto p = jit_1x1_conv_call_s();
-
- auto rp = rtus_driver_t<avx512_common>::call_params_t();
-
- const int nb_oc = jcp.nb_load;
- const int nb_ic = jcp.nb_reduce;
- const int nb_ic_blocking = jcp.nb_reduce_blocking;
- const int os_block = jcp.bcast_block;
-
- int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
- jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
-
- auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
- int &oh, int &ow, int &ih, int &iw)
- {
- int osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, bcast_end - iwork);
-
- const int os = osb * os_block;
- oh = os / jcp.ow;
- ow = os % jcp.ow;
-
- ih = nstl::max(oh * stride_h - pad_t, 0);
- iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- p.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
- rp.os = p.bcast_dim;
- };
-
- auto init_load = [&](int ocb, int &load_step)
- {
- load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
- jcp.nb_load_blocking_max);
- p.load_dim = this_block_size(ocb * jcp.oc_block,
- ocb_end * jcp.oc_block, load_step * jcp.oc_block);
- };
-
- auto init_reduce = [&](int icb)
- {
- const int nb_ic_blocking_step =
- nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
- p.first_last_flag = 0
- | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
- | (icb + nb_ic_blocking_step >= nb_ic
- ? FLAG_REDUCE_LAST : 0);
-
- p.reduce_dim = this_block_size(icb * jcp.ic_block,
- jcp.ic, nb_ic_blocking_step * jcp.ic_block);
- rp.icb = p.reduce_dim / jcp.reduce_block;
- };
-
- auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
- int ih, int iw)
- {
-
- const int _ocb = g * nb_oc + ocb;
- const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
-
- p.output_data = &dst[dst_off];
- p.bias_data = &bias[_ocb * jcp.oc_block];
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- const int _icb = g * nb_ic + icb;
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
- + _icb * jcp.is * jcp.ic_block;
- if (ocb == ocb_start) {
- rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
- rtus_driver_->ker_(&rp);
- }
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
-
- kernel_->jit_ker(&p);
- };
-
- if (jcp.loop_order == loop_rlb) {
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- iwork += bcast_step;
- }
- ocb += load_step;
- }
- }
- } else if (jcp.loop_order == loop_lbr) {
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- }
- iwork += bcast_step;
- }
- ocb += load_step;
- }
- } else if (jcp.loop_order == loop_rbl) {
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- ocb += load_step;
- }
- iwork += bcast_step;
- }
- }
- } else if (jcp.loop_order == loop_blr) {
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- init_reduce(icb);
- inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
- }
- ocb += load_step;
- }
- iwork += bcast_step;
- }
- } else {
- assert(!"unsupported loop order");
- }
-}
-
-
-template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
-/* convolution backward w.r.t. data */
-
-template <data_type_t diff_dst_type, data_type_t wei_type,
- data_type_t diff_src_type>
-void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad(ctx).template get<diff_src_data_t>(
- key_conv_rtus_space);
-
- const int ndims = diff_src_d.ndims();
-
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
-
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- const int nb_ic = jcp.nb_load;
- const int nb_oc = jcp.nb_reduce;
- const int os_block = jcp.bcast_block;
- const int nb_oc_blocking = jcp.nb_reduce_blocking;
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- parallel(0, [&](const int ithr, const int nthr) {
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx512_common>::call_params_t();
-
- int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
- jcp.nb_load, icb_start, icb_end, jcp.load_grp_count);
-
- bool reduce_outer = (jcp.loop_order == loop_rbl
- || jcp.loop_order == loop_rlb);
- int nboc_outer = reduce_outer ? nb_oc : 1;
- int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;
-
- int nboc_inner = reduce_outer ? 1 : nb_oc;
- int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;
-
- for (int ocb_outer = 0; ocb_outer < nboc_outer;
- ocb_outer += ocb_outer_step) {
- size_t cur_ocb_outer =
- nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer;
-
- int load_step = 0;
- for (int icb = icb_start; icb < icb_end; icb += load_step) {
- load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
- jcp.nb_load_blocking_max);
-
- p.load_dim = this_block_size(icb * jcp.ic_block,
- icb_end * jcp.ic_block, load_step * jcp.ic_block);
- rp.icb = p.load_dim / jcp.ic_block;
-
- int bcast_step;
- for (int iwork = bcast_start; iwork < bcast_end;
- iwork += bcast_step)
- {
- int n{0}, g{0}, osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
-
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, bcast_end - iwork);
-
- const int os = osb * os_block;
- p.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
- rp.os = p.bcast_dim;
-
- const int oh = os / jcp.ow;
- const int ow = os % jcp.ow;
- const int ih = nstl::max(oh * stride_h - pad_t, 0);
- const int iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- const int _icb = g * nb_ic + icb;
- rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_;
- p.output_data = rp.ws;
- } else
- p.output_data = rp.src;
-
- for (int ocb_inner = 0; ocb_inner < nboc_inner;
- ocb_inner += ocb_inner_step) {
- int cur_ocb_inner =
- nstl::min(ocb_inner + ocb_inner_step, nboc_inner) -
- ocb_inner;
-
- int ocb = reduce_outer ? ocb_outer : ocb_inner;
- int nb_oc_blocking_step = reduce_outer
- ? cur_ocb_outer : cur_ocb_inner;
- const int _ocb = g * nb_oc + ocb;
- size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
- p.bcast_data = &diff_dst[diff_dst_off];
-
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
-
- p.reduce_dim = this_block_size(ocb * jcp.oc_block,
- jcp.oc, nb_oc_blocking_step * jcp.oc_block);
-
- kernel_->jit_ker(&p);
- }
- if (pd()->rtus_.reduce_src_)
- rtus_driver_->ker_(&rp);
- }
- }
- }
- });
-}
-
-template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
-
-/* convolution backward w.r.t. weights */
-
-#define wht_blk_off(d, g, ...) \
- (pd()->with_groups() \
- ? (d).blk_off((g), __VA_ARGS__) \
- : (d).blk_off(__VA_ARGS__))
-
-jit_avx512_common_1x1_convolution_bwd_weights_t ::
- jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
- , trans_kernel_(nullptr), rtus_driver_(nullptr)
-{
- kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
- acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
- reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
- init_rtus_driver<avx512_common>(this);
-
- const auto &jcp = kernel_->jcp;
-
- if (jcp.transpose_src) {
- auto tp = jit_transpose4x16_src_t();
- tp.src_pf0_distance = 4;
- tp.tr_src_pf0_distance = 0;
- tp.src_pf1 = true;
- tp.tr_src_pf1 = false;
- trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
- }
-}
-
-void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights(
- const exec_ctx_t &ctx) const
-{
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- const auto scratchpad = this->scratchpad(ctx);
-
- auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
- data_t *diff_bias = pd()->wants_padded_bias()
- ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
- auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
-
- /* prepare src transposition barriers */
- auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
- auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
- key_conv_tr_src_bctx);
- if (jcp.transpose_src) {
- for (int i = 0; i < jcp.nthr; ++i)
- simple_barrier::ctx_init(&tr_src_bctx[i]);
- }
-
- const int ndims = src_d.ndims();
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
-
- simple_barrier::ctx_t reduction_barrier;
- simple_barrier::ctx_init(&reduction_barrier);
-
- const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_bia);
- auto rb = this->reducer_bias_;
- rb->init(reducer_bia_scratchpad);
-
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
-
- const int nb_ic = jcp.nb_bcast;
- const int nb_ic_blocking = jcp.nb_bcast_blocking;
-
- const int nb_oc = jcp.nb_load;
- const int nb_oc_blocking = jcp.nb_load_blocking;
-
- const int sp_nb = jcp.nb_reduce;
- const int mb_sp_work = jcp.mb * sp_nb;
-
- const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[ndims - 3];
- const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][ndims - 3];
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- // TODO: use memory descriptor with the same fmt as src
- // (or use a macro :))
- auto tr_src_off = [&](int img, int icb, int is) {
- const size_t tr_chn_size = jcp.tr_is * jcp.ic_block;
- const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups;
- return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block;
- };
-
- auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size,
- int g_start, int g_work, int ic_b_start, int ic_b_work,
- int ithr, int nthr, int first_ic_b)
- {
- const int work_amount = g_work * ic_b_work;
-
- int start{ 0 }, end{ 0 };
- balance211(work_amount, nthr, ithr, start, end);
-
- int g{ 0 }, ic_b{ 0 };
- nd_iterator_init(start, g, g_work, ic_b, ic_b_work);
- g += g_start;
- const int ic_b_tr = g * nb_ic + first_ic_b + ic_b;
- ic_b += ic_b_start;
-
- const int _ic = g * nb_ic + ic_b;
-
- const int is = sp_b_start * jcp.reduce_block;
- const int ih = is / jcp.iw;
- const int iw = is % jcp.iw;
-
- const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
- data_t *src1 = (data_t *)&src[src1_off];
- data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];
-
- assert(jcp.ic_block == 16);
- const int src_stride = jcp.is * jcp.ic_block;
- const int tr_src_stride = jcp.tr_is * jcp.ic_block;
-
- const int my_work = end - start;
- for (int iwork = 0; iwork < my_work; iwork++) {
- auto par_trans = jit_src_transpose_s();
- assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4);
- par_trans.size = sp_size;
- par_trans.src = src1;
- par_trans.tr_src = tr_src1;
- par_trans.src_prf = src1 + 64 * 16;
- par_trans.tr_src_prf = tr_src1 + 80 * 16;
- trans_kernel_->jit_ker(&par_trans);
-
- src1 += src_stride;
- tr_src1 += tr_src_stride;
- }
- };
-
- auto ker = [&](const int ithr, const int nthr) {
- assert(nthr == jcp.nthr);
- assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1));
-
- const int ithr_ic_b = ithr % jcp.nthr_ic_b;
- const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
- const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
- const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b /
- jcp.nthr_g;
-
- const int ithr_but_oc
- = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b;
-
- /* reduction dimension */
- int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 };
- if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) {
- // it's preferable to parallelize by mb if possible
- int img_start{ 0 }, img_end{ 0 };
- balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end);
- mb_sp_b_start = img_start * sp_nb;
- mb_sp_b_end = img_end * sp_nb;
- }
- else {
- balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start,
- mb_sp_b_end);
- }
-
- /* independent dimensions */
- int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 };
- int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 };
-
- balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
- balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start,
- oc_b_end);
- balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start,
- ic_b_end);
-
- const int g_work = g_end - g_start;
- const int oc_b_work = oc_b_end - oc_b_start;
- const int ic_b_work = ic_b_end - ic_b_start;
-
- data_t *diff_wei = ithr_mb == 0
- ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;
-
- int sp_b_step = 0;
- for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
- mb_sp_b += sp_b_step) {
- int img{ 0 }, sp_b{ 0 };
- nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
- sp_b_step = step(jcp.nb_reduce_blocking,
- nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
- jcp.nb_reduce_blocking_max);
-
- for (int g = g_start; g < g_end; ++g) {
- int load_step = 0;
- int bcast_step = 0;
- for (int ic_b = ic_b_start; ic_b < ic_b_end;
- ic_b += bcast_step) {
- bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
- jcp.nb_bcast_blocking_max);
- if (jcp.transpose_src) {
- if (jcp.nthr_oc_b > 1)
- simple_barrier::barrier(
- &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
- const int sp_size
- = nstl::min(sp_b_step * jcp.reduce_block,
- jcp.is - sp_b * jcp.reduce_block);
- uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b,
- bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
- if (jcp.nthr_oc_b > 1)
- simple_barrier::barrier(
- &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
- }
-
- for (int oc_b = oc_b_start; oc_b < oc_b_end;
- oc_b += load_step) {
- load_step = step(nb_oc_blocking, oc_b_end - oc_b,
- jcp.nb_load_blocking_max);
- const int _ic_b = g * nb_ic + ic_b;
- const int _ic_b_tr = g * nb_ic + ic_b_start;
- const int _oc_b = g * nb_oc + oc_b;
-
- data_t *store_to;
-
- const size_t off
- = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
- store_to = diff_wei + off;
-
- const data_t *diff_src = jcp.transpose_src ?
- &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
- &src[src_d.blk_off(img, _ic_b)];
-
- int sp_b_end = sp_b + sp_b_step;
- const data_t *pdiff_dst
- = &diff_dst[diff_dst_d.blk_off(img, _oc_b)];
- const data_t *local_src = diff_src;
-
- auto p = jit_1x1_conv_call_s();
- auto rp = rtus_driver_t<avx512_common>::call_params_t();
-
- p.output_stride
- = jcp.ic * jcp.oc_block * jcp.typesize_out;
-
- p.load_dim = load_step * jcp.oc_block;
-
- p.bcast_dim = bcast_step * jcp.ic_block;
- rp.icb = bcast_step;
- p.output_data = store_to;
-
- p.reduce_dim = sp_b_step * jcp.reduce_block;
- rp.os = p.reduce_dim;
-
- p.first_last_flag = 0
- | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0)
- | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);
-
- int sp = sp_b * jcp.reduce_block;
- p.load_data = pdiff_dst + sp * jcp.oc_block;
-
- if (pd()->rtus_.reduce_src_) {
- const int oh = sp / jcp.ow;
- const int ow = sp % jcp.ow;
-
- const int ih = nstl::max(oh * stride_h - pad_t, 0);
- const int iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- rp.ws = rtus_space
- + ithr * pd()->rtus_.space_per_thread_
- + sp * jcp.ic_block;
-
- if (ndims == 3)
- rp.src = local_src + iw
- * src_d.blocking_desc().strides[2];
- else
- rp.src = local_src + ih
- * src_d.blocking_desc().strides[2]
- + iw * src_d.blocking_desc().strides[3];
- rtus_driver_->ker_(&rp);
-
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = local_src + sp * jcp.ic_block;
-
- kernel_->jit_ker(&p);
- }
- }
- }
- }
-
- /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
- if (jcp.nthr_mb > 1) {
- simple_barrier::barrier(&reduction_barrier, jcp.nthr);
- const int work = g_work * oc_b_work * ic_b_work;
- int start{ 0 }, end{ 0 };
- balance211(work, jcp.nthr_mb, ithr_mb, start, end);
- if (start == end)
- return;
-
- for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
- int w = start;
- int sub_g_start{ 0 }, sub_oc_b_start{ 0 },
- sub_ic_b_start{ 0 };
- nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
- oc_b_work, sub_ic_b_start, ic_b_work);
- while (w < end) {
- const int g = g_start + sub_g_start;
- const int oc_b = oc_b_start + sub_oc_b_start;
- const int ic_b = ic_b_start + sub_ic_b_start;
-
- const int acc_size
- = nstl::min(end - w, ic_b_work - sub_ic_b_start)
- * jcp.ic_block * jcp.oc_block;
-
- const size_t off
- = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
- data_t *d = diff_weights + off;
- data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;
-
- acc_ker_->accumulate(d, s, acc_size);
-
- nd_iterator_jump(w, end, sub_g_start, g_work,
- sub_oc_b_start, oc_b_work, sub_ic_b_start,
- ic_b_work);
- }
- }
- }
- };
-
- auto ker_bias = [&](int ithr, int nthr) {
- assert(nthr == rb->balancer().nthr_);
-
- const int b_job_start = rb->balancer().ithr_job_off(ithr);
- const int b_njobs = rb->balancer().ithr_njobs(ithr);
-
- if (b_njobs == 0)
- return;
-
- /* reduction dimension */
- int img_start{ 0 }, img_end{ 0 };
-
- balance211(jcp.mb, rb->balancer().nthr_per_group_,
- rb->balancer().id_in_group(ithr), img_start, img_end);
-
- /* jobs */
- int g_start{ 0 }, ocb_start{ 0 };
- nd_iterator_init(
- b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);
-
- for (int img = img_start; img < img_end; ++img) {
- int g = g_start, ocb = ocb_start;
- for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
- const size_t _oc = g * jcp.nb_load + ocb;
-
- const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
- data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
- reducer_bia_scratchpad)
- + b_job_loc * rb->balancer().job_size_;
-
- if (img == img_start)
- for (int o = 0; o < 16; ++o)
- d_bias[o] = 0.;
-
- for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
- PRAGMA_OMP_SIMD()
- for (int o = 0; o < 16; ++o)
- d_bias[o] += d_dst[o];
- d_dst += 16;
- }
-
- nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
- }
- }
- rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
- };
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- ker(ithr, jcp.nthr);
- if (pd()->with_bias())
- ker_bias(ithr, jcp.nthr);
- });
-
- /* TODO: put this in ker_bias */
- if (pd()->wants_padded_bias()) {
- assert(jcp.ngroups == 1);
- utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
- }
-}
-
-}
-}
-}
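The execution code removed above splits work with balance211-style partitioning (directly and via balance2D): a 1-D range of n items is divided across nthr threads so that the first n % nthr threads receive one extra item. A self-contained sketch of that splitting rule follows; split_work is an illustrative name, the real helper lives in mkldnn's common utilities.

// Split [0, n) across nthr threads; thread ithr works on [start, end).
// The first (n % nthr) threads each receive one extra item.
static void split_work(int n, int nthr, int ithr, int &start, int &end) {
    const int base = n / nthr;
    const int rem = n % nthr;
    start = ithr * base + (ithr < rem ? ithr : rem);
    end = start + base + (ithr < rem ? 1 : 0);
}

// Example: n = 10, nthr = 4 yields the ranges [0,3), [3,6), [6,8), [8,10).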
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp
deleted file mode 100644
index 2e9fda76d6..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp
+++ /dev/null
@@ -1,344 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
-#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_reducer.hpp"
-
-#include "jit_avx512_common_1x1_conv_kernel.hpp"
-#include "jit_uni_1x1_conv_utils.hpp"
-#include "jit_transpose_src_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type,
- impl::data_type_t wei_type = src_type,
- impl::data_type_t dst_type = src_type>
-struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
- jit_avx512_common_1x1_convolution_fwd_t);
-
- status_t init() {
- using namespace utils;
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, wei_type, dst_type, dst_type,
- data_type::undef)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *src_d = src_md();
- rtus_prepare(this, conv_d, src_d, dst_md());
-
- status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
- jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(),
- mkldnn_get_max_threads(), rtus_.reduce_src_);
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
- jcp_);
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- jit_1x1_conv_conf_t jcp_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
- OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), rtus_driver_(nullptr)
- {
- kernel_ =
- new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
- init_rtus_driver<avx512_common>(this);
- }
-
- ~jit_avx512_common_1x1_convolution_fwd_t() {
- delete kernel_;
- delete rtus_driver_;
- }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
- private:
- void execute_forward(const exec_ctx_t &ctx) const;
- void execute_forward_thr(const int ithr, const int nthr,
- const src_data_t *src, const wei_data_t *weights,
- const dst_data_t *bias, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_common_1x1_conv_kernel *kernel_;
- rtus_driver_t<avx512_common> *rtus_driver_;
-};
-
-using jit_avx512_common_1x1_convolution_fwd_f32_t
- = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
-
-template <impl::data_type_t diff_dst_type,
- impl::data_type_t wei_type = diff_dst_type,
- impl::data_type_t diff_src_type = diff_dst_type>
-struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
- jit_avx512_common_1x1_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(diff_src_type, wei_type, data_type::undef,
- diff_dst_type, data_type::undef)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *diff_src_d = diff_src_md();
- rtus_prepare(this, conv_d, diff_src_d, diff_dst_md());
-
- status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
- jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
- *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_);
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
- jcp_);
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- // TODO (Roma): structs conf header cleanup
- jit_1x1_conv_conf_t jcp_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
- IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), rtus_driver_(nullptr)
- {
- kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_,
- *pd()->attr());
- init_rtus_driver<avx512_common>(this);
- }
-
- ~jit_avx512_common_1x1_convolution_bwd_data_t() {
- delete kernel_;
- delete rtus_driver_;
- }
-
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
- private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_common_1x1_conv_kernel *kernel_;
- rtus_driver_t<avx512_common> *rtus_driver_;
-};
-
-using jit_avx512_common_1x1_convolution_bwd_data_f32_t
- = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
-
-struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t
-{
- struct pd_t : public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
- jit_avx512_common_1x1_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *src_d = src_md();
- rtus_prepare(this, conv_d, src_d, diff_dst_md());
-
- status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
- jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
- *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_);
- if (status != status::success) return status;
-
- init_balancers();
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
- jcp_);
-
- auto reducer_bia_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_bia);
- reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- // TODO (Roma): structs conf header cleanup
- jit_1x1_conv_conf_t jcp_;
- cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
- OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
-
- private:
- void init_balancers() {
- const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
- if (with_bias()) {
- reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
- jcp_.oc_block, jcp_.ngroups * jcp_.nb_load,
- jcp_.mb, max_buffer_size));
- }
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd);
-
- ~jit_avx512_common_1x1_convolution_bwd_weights_t() {
- delete kernel_;
- delete acc_ker_;
- delete reducer_bias_;
- delete rtus_driver_;
- delete trans_kernel_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
- private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_common_1x1_conv_kernel *kernel_;
- cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
- cpu_reducer_t<data_type::f32> *reducer_bias_;
- jit_transpose4x16_src *trans_kernel_;
- rtus_driver_t<avx512_common> *rtus_driver_;
-};
-
-}
-}
-}
-
-#endif
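The descriptors removed above select their blocked memory formats with utils::pick(ndims() - 3, ...), i.e. by indexing a short list of layout tags with the spatial rank (1-D, 2-D or 3-D convolution). Here is a minimal sketch of such an index-into-arguments helper, written with an initializer_list rather than mkldnn's actual variadic implementation; the enum values are illustrative.

#include <cassert>
#include <initializer_list>

enum class layout_tag { nCw16c, nChw16c, nCdhw16c };

// Return the idx-th element of the argument list, like utils::pick.
static layout_tag pick(int idx, std::initializer_list<layout_tag> tags) {
    assert(idx >= 0 && idx < (int)tags.size());
    return *(tags.begin() + idx);
}

// A 2-D (nchw) convolution has ndims == 4, so idx = ndims - 3 = 1 -> nChw16c.
static layout_tag data_tag_for(int ndims) {
    return pick(ndims - 3,
            { layout_tag::nCw16c, layout_tag::nChw16c, layout_tag::nCdhw16c });
}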
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
deleted file mode 100644
index 235fb02fef..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
+++ /dev/null
@@ -1,4539 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_barrier.hpp"
-
-#include "jit_avx512_common_conv_kernel.hpp"
-
-#define GET_OFF(field) offsetof(jit_conv_call_s, field)
-#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-namespace {
-
-constexpr auto small_spatial = 14;
-unsigned int L1_cache_size = get_cache_size(1, true);
-
-inline void pick_loop_order(jit_conv_conf_t &jcp) {
- using namespace prop_kind;
- assert(one_of(jcp.prop_kind,
- forward_training, forward_inference, backward_data));
- auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
- auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
-
- // ow-threading is currently implemented for forward only
- // TODO: use a single code path for fwd and bwd once ow-threading is
- // implemented for bwd; a meaningless switch was removed
- if (jcp.prop_kind == backward_data) {
- jcp.loop_order = (w <= small_spatial && h <= small_spatial)
- ? loop_cgn : loop_gnc;
- } else {
- jcp.loop_order = (w <= small_spatial && h <= small_spatial)
- ? loop_cwgn : loop_gncw;
- }
-}
-
-inline bool is_1stconv(const jit_conv_conf_t &jcp) {
- if (mayiuse(avx512_core))
- return (jcp.ic < 16 && jcp.ngroups == 1);
- else
- return one_of(jcp.ic, 1, 3);
-}
-
-inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
- return (jcp.nb_ow > 1);
-}
-
-inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
- return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
-}
-
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
-{
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- vpxord(vmm, vmm, vmm);
- if (!is_owb_prefetching(jcp)) {
- size_t aux_output_offset = get_output_offset(j, k);
- mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
- aux_output_offset, reg_out_long_offt));
- }
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
-{
- Label no_update_label, store_label, eltwise_label;
-
- mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
- if (jcp.with_bias) {
- mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
- }
-
- if (!jcp.with_sum) {
- cmp(reg_channel, 0);
- je(no_update_label, T_NEAR);
- }
-
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- size_t aux_output_offset = get_output_offset(j, k);
- vaddps(vmm,
- make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
- }
-
- if (!jcp.with_sum) {
- jmp(eltwise_label, T_NEAR);
- } else {
- cmp(reg_channel, 0);
- jne(eltwise_label, T_NEAR);
- }
-
- L(no_update_label);
- if (jcp.with_bias) {
- for (int k = 0; k < jcp.nb_oc_blocking; k++) {
- int bias_offset = jcp.typesize_out * k * jcp.oc_block;
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset));
- }
- mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
- }
- }
-
- L(eltwise_label);
- if (jcp.with_eltwise) {
- cmp(reg_channel, jcp.nb_ic - 1);
- jl(store_label, T_NEAR);
-
- if (ur_w == jcp.ur_w) {
- eltwise_injector_->compute_vector_range(0,
- jcp.nb_oc_blocking * jcp.ur_w);
- } else {
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- eltwise_injector_->compute_vector_range(k * jcp.ur_w,
- k * jcp.ur_w + ur_w);
- }
- }
-
- L(store_label);
- for (int k = 0; k < jcp.nb_oc_blocking; k++)
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- size_t aux_output_offset = (size_t)typesize *
- ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
- vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
- reg_out_long_offt), vmm);
- if (!is_owb_prefetching(jcp))
- mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
- aux_output_offset, reg_out_long_offt));
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
- int pad_l, int pad_r)
-{
-}
-
-template<>
-void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
- int pad_l, int pad_r)
-{
- assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
-
- int iw = jcp.iw;
- int ih = jcp.ih;
- int kw = jcp.kw;
- int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
-
- Label kh_label, kd_label;
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_inp_prf, reg_inp_prf);
- }
-
- size_t max_input_offset = (size_t)jcp.typesize_in
- * ((size_t)(kw + ur_w * stride_w - pad_l)
- + (size_t)ic_block * iw * ih * jcp.id);
- assert(reg_inp_prf == reg_long_offt);
- if (max_input_offset > INT_MAX) push(reg_inp_prf);
-
- if (jcp.ndims == 5) {
- push(reg_out_prf);
- push(reg_out);
-
- mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
- mov(aux_reg_inp_d, reg_inp);
- mov(aux_reg_inp_d_prf, reg_inp_prf);
-
- L(kd_label);
- }
- mov(reg_kj, reg_kh);
- if (jcp.ndims == 5) {
- mov(aux_reg_inp, aux_reg_inp_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
- }
-
- L(kh_label);
- for (int ki = 0; ki < kw; ki += 4) {
- for (int ic = 0; ic < ic_block; ic++) {
- for (int i = 0; i < 4; i++) {
- int aux_ker_offset
- = jcp.typesize_in
- * ((ki + i) * oc_block
- + ic * kw * jcp.kh * jcp.kd * oc_block);
- if (ki + i < kw)
- vmovups(vmm_ker(i),
- EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
- else
- vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
- }
-
- int j_start = get_ow_start(ki, pad_l);
- int j_end = get_ow_end(ur_w, ki, pad_r);
-
- for (int j = j_start, prf_count=0; j < j_end; j++) {
- size_t aux_input_offset = (size_t)jcp.typesize_in
- * ((size_t)(ki + j * stride_w
- - pad_l) + (size_t)ic * iw * ih * jcp.id);
- v4fmaddps(vmm_out(j, 0), vmm_ker(0),
- EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
- reg_long_offt));
- if (ki + prf_count < kw && prf_count < 4
- && ((ki < 2 && j % 4) || j % 2)) {
- int aux_ker_offset = jcp.typesize_in
- * ((ki + prf_count) * oc_block
- + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block);
- mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
- aux_ker_offset));
- prf_count++;
- }
- if (ki == 0
- && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
- mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
- aux_input_offset, reg_long_offt));
- }
- if (ki == 1
- && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
- mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
- aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
- }
- }
- }
- }
- add(aux_reg_ker, jcp.typesize_in * kw * oc_block);
- add(aux_reg_inp, jcp.typesize_in * iw);
- add(aux_reg_inp_prf, jcp.typesize_in * iw);
-
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
-
- if (jcp.ndims == 5) {
- add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
- add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_out);
- pop(reg_out_prf);
- }
-
- if (max_input_offset > INT_MAX) pop(reg_inp_prf);
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
- int pad_l, int pad_r)
-{
-}
-
-template<>
-void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
- int pad_l, int pad_r)
-{
- int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- Label kh_label, last_iter_label, loop_end_label, kd_label;
- int ker_load_number = 4;
- int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
- int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
-
- bool check_last_kh = (jcp.kh > 3);
- bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);
-
- int oi_ipref_t0 = get_ow_start(0, pad_l);
- int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);
-
- assert(jcp.oc % jcp.nb_oc_blocking == 0);
-
- auto kernel_offset = [=](int ocb, int ic, int ki) {
- int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
- int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
- int ic_offset = ic * jcp.oc_block;
- return typesize * (blk_offset + ic_offset);
- };
- auto kernel_loads = [=](int ki, int ic, int kk) {
- for (int ii = 0; ii < ker_load_number; ii++) {
- int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
- vmovups(vmm_ker(ii),
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- }
- };
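- // note: v4fmaddps consumes a block of four consecutive kernel registers per
- // issue, which is why kernel_loads above always fills ker_load_number (= 4)
- // vmm_ker registers before entering the inner ow loop.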
- auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
- if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
- && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
- int aux_inp_offset
- = typesize
- * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
- + (jcp.dilate_h + 1) * jcp.iw * ic_block);
- prefetcht0(EVEX_compress_addr(aux_reg_inp,
- aux_inp_offset));
- oi_ipref_t0++;
- }
- };
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_ker_prf, reg_ker_prf);
- mov(aux_reg_inp_prf, reg_inp_prf);
- }
-
- if (jcp.ndims == 5) {
- push(reg_out_prf);
- push(reg_out);
-
- mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
- mov(aux_reg_inp_d, reg_inp);
- mov(aux_reg_inp_d_prf, reg_inp_prf);
- mov(aux_reg_ker_d_prf, reg_ker_prf);
- L(kd_label);
- mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
- if (jcp.ndims == 5) {
- mov(aux_reg_inp, aux_reg_inp_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
- mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
- }
-
- align(16);
- L(kh_label);
- int kw = jcp.kw;
- if (check_last_kh) {
- for (int ki = 0; ki < kw; ki++)
- for (int ic = 0; ic < ic_block; ic += 4)
- for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
- bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
- && ki == kw - 1 && (ic + 4) == ic_block);
-
- if (last_kernel_loads) {
- cmp(reg_kj, 1);
- je(last_iter_label, T_NEAR);
- }
-
- kernel_loads(ki, ic, kk);
- for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
- prf_count_t0 = 0;
- oi < get_ow_end(ur_w, ki, pad_r); oi++) {
- int aux_input_offset = typesize
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w
- - pad_l) * ic_block
- + ic);
- v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
- EVEX_compress_addr(aux_reg_inp, aux_input_offset));
-
- if (oi % 2) {
- if (prf_count_t0 < 4) {
- int aux_kernel_prf;
- if (last_kernel_loads)
- aux_kernel_prf = kernel_offset(0,
- prf_count_t0 + ic + 4
- - ic_block, 0) + typesize * kw
- * oc_block * ic_block;
- else
- aux_kernel_prf = kernel_offset(kk, ic + 4
- + prf_count_t0, ki);
- mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
- aux_kernel_prf));
- prf_count_t0++;
- } else if (prf_count_t1 < 4) {
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_ker_prf, kernel_offset(kk, ic
- + prf_count_t1, ki)));
- prf_count_t1++;
- }
- } else
- prefetch_inp_next_kh(ki, 2, prf_count_t0,
- prf_count_t1);
- }
-
- if (last_kernel_loads) {
- jmp(loop_end_label, T_NEAR);
-
- L(last_iter_label);
-
- kernel_loads(ki, ic, kk);
- for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
- prf_count_t0 = 0;
- oi < get_ow_end(ur_w, ki, pad_r); oi++) {
- int aux_input_offset = typesize
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w
- - pad_l) * ic_block
- + ic);
- v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
- EVEX_compress_addr(aux_reg_inp,
- aux_input_offset));
- if (oi % 2) {
- if (prf_count_t0 < 4) {
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_ker_prf, kernel_offset(0,
- prf_count_t0, 0)));
- prf_count_t0++;
- } else if (prf_count_t1 < 4) {
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_ker_prf, kernel_offset(kk,
- ic + prf_count_t1, ki)));
- prf_count_t1++;
- }
- }
- }
- L(loop_end_label);
- }
- }
- } else {
- for (int ki = 0; ki < kw; ki++)
- for (int ic = 0; ic < ic_block; ic += 4)
- for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
- kernel_loads(ki, ic, kk);
- for (int oi = get_ow_start(ki, pad_l),
- prf_count_t1 = 0, prf_count_t0 = 0;
- oi < get_ow_end(ur_w, ki, pad_r); oi++) {
- int aux_input_offset = typesize
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w
- - pad_l) * ic_block + ic);
- v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
- EVEX_compress_addr(aux_reg_inp,
- aux_input_offset));
-
- if (!is_owb_prefetching(jcp)) {
- if ((oi % 2) && (prf_count_t1 < 4)) {
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_ker_prf, kernel_offset(kk,
- ic + prf_count_t1, ki)));
- prf_count_t1++;
- }
- } else {
- if (!(ki == 0 && ic == 0)
- && !(ki == kw-1 && ic == 0) &&
- (oi % 2) && (prf_count_t1 < 4)
- ) {
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_ker, kernel_offset(kk,
- ic + 4 + prf_count_t0, ki)));
- prf_count_t0++;
- }
- }
- if (!is_owb_prefetching(jcp)) {
- if (pref_current_inp) {
- if (ki == 0 && ic == 0 && kk == 0)
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_inp,
- aux_input_offset + shift_input_ptr));
- } else {
- if (ki == 1 && ic == 0 && kk == 0)
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_inp_prf, aux_input_offset));
- }
- } else {
- int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int inp_shift
- = jcp.typesize_in * ur_w * stride_w * inp_mult;
- bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
- if (ki == 0 && ic == 0 && kk_pref_slot)
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_inp,
- aux_input_offset + inp_shift));
-
- if (ki == kw - 1 && ic == 0 && kk_pref_slot)
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_inp,
- aux_input_offset + inp_shift));
- }
- }
- }
- }
-
- add(aux_reg_ker, shift_kernel_ptr);
- add(aux_reg_inp, shift_input_ptr);
- add(aux_reg_ker_prf, shift_kernel_ptr);
- add(aux_reg_inp_prf, shift_input_ptr);
-
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
-
- if (jcp.ndims == 5) {
- add(aux_reg_inp_d,
- typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
- add(aux_reg_inp_d_prf,
- typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
- add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_out);
- pop(reg_out_prf);
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
- int pad_l, int pad_r)
-{
- bool prf_ker = true;
- bool prf_inp = true;
- int ih = jcp.ih;
- int stride_w = jcp.stride_w;
- int id = jcp.id;
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int nb_oc_block = jcp.nb_oc_blocking;
- Label kh_label, kd_label;
-
- int ker_pipeline_depth = 4;
- assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
- assert(oc_block >= ker_pipeline_depth);
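- // kernel loads are software-pipelined ker_pipeline_depth (= 4) registers
- // deep: the first step preloads all four vmm_ker registers, and each later
- // step loads one register ahead while the oldest one feeds the FMAs.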
-
- int num_ker_loads = ic_block * nb_oc_block * kw;
- int num_ker_prfs = prf_ker ? num_ker_loads : 0;
- int num_inp_prfs = prf_inp ?
- ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
- 0;
- if (jcp.is_1stconv && prf_inp) {
- num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
- }
- int num_prfs = num_ker_prfs + num_inp_prfs;
- int num_fmas = num_ker_loads * ur_w;
- int prf_inst_spacing
- = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
- int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
- int inp_mul = !jcp.is_1stconv ? ic_block : 1;
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_inp_prf, reg_inp_prf);
- mov(aux_reg_ker_prf, reg_ker_prf);
- }
-
- size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
- assert(reg_inp_prf == reg_long_offt);
- if (max_input_offset > INT_MAX) push(reg_inp_prf);
-
-
- if (jcp.ndims == 5) {
- push(reg_out_prf);
- push(reg_out);
-
- mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
- mov(aux_reg_inp_d, reg_inp);
- mov(aux_reg_inp_d_prf, reg_inp_prf);
- mov(aux_reg_ker_d_prf, reg_ker_prf);
-
- L(kd_label);
- mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_inp, aux_reg_inp_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
- mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
- }
-
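- // the kh loop below interleaves roughly one software prefetch per
- // prf_inst_spacing FMAs (at offset prf_inst_trigger), spreading the kernel
- // and input prefetches evenly across the unrolled compute body.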
- align(16);
- L(kh_label);
- {
- int step = 0;
- int ker_prfs = 0;
- for (int ki = 0; ki < kw; ki++) {
- for (int ic = 0; ic < ic_block; ic++) {
- int aux_kernel_offset = 0;
- if (step == 0) {
- for (int i = 0; i < ker_pipeline_depth; i++) {
- aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
- vmovups(vmm_ker(i), EVEX_compress_addr(
- aux_reg_ker, aux_kernel_offset));
- }
- } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
- int load_offset = ker_pipeline_depth - 1;
- int ker_load_reg_idx
- = (step + load_offset) % ker_pipeline_depth;
- aux_kernel_offset
- = get_kernel_offset(ki, ic, 0, load_offset);
- vmovups(vmm_ker(ker_load_reg_idx),
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- }
-
- bool ker_prf_inserted = false;
- Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
- int j_start = get_ow_start(ki, pad_l);
- int j_end = get_ow_end(ur_w, ki, pad_r);
- for (int j = j_start; j < j_end; j++) {
- size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
- auto addr = EVEX_compress_addr_safe(aux_reg_inp,
- aux_input_offset, reg_long_offt, true);
- vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
- int fma_idx = step * ur_w + j;
- int prf_slot_idx = fma_idx / prf_inst_spacing;
- if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
- if (prf_ker && !ker_prf_inserted
- && ker_prfs < num_ker_prfs) {
- int ker_prf_offset
- = jcp.typesize_in * ker_prfs * jcp.oc_block;
- mic_prefetcht2(EVEX_compress_addr(
- aux_reg_ker_prf, ker_prf_offset));
- ker_prf_inserted = true;
- ker_prfs++;
- } else if (prf_inp) {
- int inp_prf_idx = prf_slot_idx - ker_prfs;
- if (inp_prf_idx < num_inp_prfs) {
- size_t inp_prf_stride = nstl::max(kw, stride_w);
- size_t inp_prf_offset;
- if (!jcp.is_1stconv) {
- inp_prf_offset
- = ic_block * jcp.typesize_in
- * ((inp_prf_idx / kw)
- * inp_prf_stride
- + (inp_prf_idx % kw));
- } else {
- size_t ic_prf_stride =
- (size_t)jcp.typesize_in * iw * ih * id;
- size_t iw_prf_stride
- = jcp.typesize_in * jcp.simd_w;
- inp_prf_offset = ((inp_prf_idx / ic_block)
- * iw_prf_stride
- + (inp_prf_idx % ic_block)
- * ic_prf_stride);
- }
- mic_prefetcht0(EVEX_compress_addr_safe(
- aux_reg_inp_prf, inp_prf_offset,
- reg_long_offt));
- }
- }
- }
- }
- step++;
- }
- }
- add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
- if (prf_ker)
- add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
- add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
- if (prf_inp)
- add(aux_reg_inp_prf,
- jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
- }
-
-
- if (jcp.ndims == 5) {
- add(aux_reg_inp_d,
- typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
- add(aux_reg_inp_d_prf,
- typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
- add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_out);
- pop(reg_out_prf);
- }
- if (max_input_offset > INT_MAX) pop(reg_inp_prf);
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
- int pad_l, int pad_r)
-{
- int kw = jcp.kw;
- int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int nb_oc_block = jcp.nb_oc_blocking;
- Label kh_label, kd_label;
- int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
- * jcp.ic_block;
- int inp_mul = !jcp.is_1stconv ? ic_block : 1;
- int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
- * inp_mul;
-
-
- auto input_offset = [=](int oi, int ic, int ki) {
- return (size_t)jcp.typesize_in
- * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
- * inp_mul + (size_t)ic
- * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
- };
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
- }
-
- if (jcp.ndims == 5) {
- push(reg_out);
-
- mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
- mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
- mov(aux_reg_inp_d, reg_inp);
-
- L(kd_label);
- mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_inp, aux_reg_inp_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- }
-
- L(kh_label);
- {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = get_ow_start(ki, pad_l);
- int jj_end = get_ow_end(ur_w, ki, pad_r);
- for (int ic = 0; ic < ic_block; ic++) {
- if (jcp.kernel_kind == expl_bcast) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- size_t aux_input_offset = input_offset(jj, ic, ki);
- vbroadcastss(vmm_inp(jj, nb_oc_block),
- EVEX_compress_addr_safe(aux_reg_inp,
- aux_input_offset, reg_long_offt));
- }
- }
- for (int ii = 0; ii < nb_oc_block; ii++) {
- int aux_kernel_offset = jcp.typesize_in
- * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
- * oc_block + ki * ic_block * oc_block + ic * oc_block);
- if (jj_end - jj_start > 0)
- vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
- aux_kernel_offset));
- for (int jj = jj_start; jj < jj_end; jj++)
- if (jcp.kernel_kind == expl_bcast)
- vfmadd231ps(vmm_out(jj, ii),
- vmm_inp(jj, nb_oc_block), vmm_wei);
- else {
- size_t aux_input_offset = input_offset(jj, ic, ki);
- vfmadd231ps(vmm_out(jj, ii), vmm_wei,
- EVEX_compress_addr_safe(aux_reg_inp,
- aux_input_offset, reg_long_offt, true));
- }
- }
- }
- }
- add(aux_reg_ker, shift_kernel_ptr);
- add(aux_reg_inp, shift_input_ptr);
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
- }
-
- if (jcp.ndims == 5) {
- add(aux_reg_inp_d,
- typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
- * jcp.ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_out);
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
- int pad_l, int pad_r)
-{
- if (jcp.ndims == 5) push(reg_oi);
-
- prepare_output(ur_w);
-
- Label skip_compute_loop;
- if (jcp.ndims == 5) {
- if ((jcp.dilate_d >= jcp.id)
- || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
- mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
- cmp(reg_kj, 0);
- je(skip_compute_loop, T_NEAR);
- }
- }
- if ((jcp.dilate_h >= jcp.ih)
- || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
- cmp(reg_kj, 0);
- je(skip_compute_loop, T_NEAR);
- }
-
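- // dispatch to the compute-loop variant chosen in init_conf: the 4fma paths
- // for ver_4fma, the pipelined FMA loop for avx512_mic, first-conv and
- // single-block embd_bcast cases, and compute_loop_fma_core otherwise.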
- if (jcp.ver == ver_4fma)
- if(jcp.is_1stconv)
- compute_loop_4fma_1st(ur_w, pad_l, pad_r);
- else
- compute_loop_4fma(ur_w, pad_l, pad_r);
- else if (jcp.ver == ver_fma)
- if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
- || mayiuse(avx512_mic))
- compute_loop_fma(ur_w, pad_l, pad_r);
- else
- if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
- compute_loop_fma(ur_w, pad_l, pad_r);
- else
- compute_loop_fma_core(ur_w, pad_l, pad_r);
- else
- assert(!"unknown convolution version");
-
- L(skip_compute_loop);
- store_output(ur_w);
- if (jcp.ndims == 5) pop(reg_oi);
-}
-
-template<typename Vmm>
-void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
-{
- int iw = jcp.iw;
- int ow = jcp.ow;
- int ow_block = jcp.ow_block;
- int nb_ow = jcp.nb_ow;
- int kw = jcp.kw;
- int l_pad = jcp.l_pad;
- int ur_w = jcp.ur_w;
- int ur_w_tail = jcp.ur_w_tail;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
-
- int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
- int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
- int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
- int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;
-
- preamble();
- mov(reg_inp, ptr[param1 + GET_OFF(src)]);
- mov(reg_out, ptr[param1 + GET_OFF(dst)]);
- mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
- mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
- mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
-
- int r_pad = nstl::max(
- 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
- int n_oi = ow / ur_w;
- int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
- - (iw + l_pad - 1);
-
- if (!is_ow_threading_on(jcp)) {
- // ow is processed as a whole, including left and right padding
- if (r_pad1 > 0) n_oi--;
-
- if (ow == ur_w) {
- mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]);
- mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]);
- compute_loop(ur_w, l_pad, r_pad);
- } else {
- mov(reg_inp_prf, reg_inp);
- mov(reg_out_prf, reg_out);
- if (n_oi == 0) {
- add(reg_inp_prf, inp_shift_pad);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, l_pad, r_pad1);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
- if (ur_w_tail != 0) {
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w_tail, 0, r_pad);
- }
- } else {
- xor_(reg_oi, reg_oi);
- if (l_pad > 0) {
- add(reg_inp_prf, inp_shift_pad);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, l_pad, 0);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
- inc(reg_oi);
- }
- if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
- Label ow_loop_label;
- L(ow_loop_label);
- {
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, 0, 0);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
- inc(reg_oi);
- cmp(reg_oi, n_oi);
- jl(ow_loop_label, T_NEAR);
- }
- }
- if (r_pad1 > 0) {
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, 0, r_pad1);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
- }
- if (ur_w_tail != 0) {
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w_tail, 0, r_pad);
- }
- }
- }
- } else {
- // only a single ow block is processed here;
- // the block number is passed as the parameter owb,
- // and the padding handling depends on it.
-
- Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
- Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
-
- assert(ow_block % ur_w == 0);
- int n_oi_not_last_ow_block = ow_block / ur_w;
- // to simplify the code (and general-purpose register usage),
- // the size of an ow block must be >= 2 * ur_w
- assert(n_oi_not_last_ow_block > 1);
- int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
- int n_oi_first_ow_block = n_oi_not_last_ow_block;
-
- int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
-
- // prepare right padding
- bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
- bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
- bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
-
- if (last_ow_block_padded) n_oi_last_ow_block--;
- else if (first_ow_block_padded) n_oi_first_ow_block--;
- else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
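- // at most one block absorbs the right padding: the last block when it still
- // contains a full ur_w iteration, otherwise the next-to-last one (which is
- // the first block when nb_ow == 2); that block runs one fewer unpadded
- // iteration.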
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
- cmp(reg_owb, 0); // is this the first ow-block?
- jg(middle_ow_blocks_label, T_NEAR);
-
- // the first ow block, compute left padding
-
- mov(reg_oi, n_oi_first_ow_block);
- mov(reg_inp_prf, reg_inp);
- mov(reg_out_prf, reg_out);
-
- if (l_pad > 0) {
- mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
- add(reg_inp_prf, inp_shift_pad);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, l_pad, 0);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
- dec(reg_oi);
- }
- jmp(oi_loop_label, T_NEAR);
-
- // middle or last ow block entry
-
- L(middle_ow_blocks_label);
-
- if (l_pad > 0) {
- // only account for the left padding shift here; nothing is computed
- add(reg_inp, inp_shift_pad_second_block);
- add(reg_inp_prf, inp_shift_pad_second_block);
- }
-
- // set the number of iterations for the oi-loop
- cmp(reg_owb, jcp.nb_ow - 1); // last ow-block?
- mov(reg_oi, n_oi_last_ow_block);
- je(oi_loop_label, T_NEAR);
- cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
- mov(reg_oi, n_oi_next_last_ow_block);
- je(oi_loop_label, T_NEAR);
- mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
-
- // oi loop w/o padding
- L(oi_loop_label);
- mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
- L(oi_loop_start_label);
- cmp(reg_oi, 0);
- jle(oi_loop_end_label, T_NEAR);
-
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, 0, 0);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
- dec(reg_oi);
- jmp(oi_loop_start_label, T_NEAR);
- L(oi_loop_end_label);
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
-
- cmp(reg_owb, 0); // first ow-block?
- if (first_ow_block_padded) {
- je(last_oi_label, T_NEAR);
- } else {
- je(end_label, T_NEAR);
- }
- cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block?
- jl(end_label, T_NEAR);
- if (next_last_ow_block_padded) {
- je(last_oi_label, T_NEAR);
- } else {
- je(end_label, T_NEAR);
- }
- // this is the last block
- if (!last_ow_block_padded) {
- jmp(tail_label, T_NEAR);
- }
-
- // last oi block with right padding
- L(last_oi_label);
- mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w, 0, r_pad1);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
- cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
- jl(end_label, T_NEAR);
-
- L(tail_label);
- mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
- if (ur_w_tail != 0) {
- add(reg_inp_prf, inp_shift);
- add(reg_out_prf, out_shift);
- compute_loop(ur_w_tail, 0, r_pad);
- }
- L(end_label);
- }
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_common_conv_fwd_kernel::init_conf(
- jit_conv_conf_t &jcp, const convolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &weights_md,
- memory_desc_t &dst_md, memory_desc_t &bias_md,
- const primitive_attr_t &attr, int nthreads)
-{
- using namespace prop_kind;
-
- if (!mayiuse(avx512_common))
- return status::unimplemented;
-
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper weights_d(&weights_md);
- const memory_desc_wrapper dst_d(&dst_md);
- const memory_desc_wrapper bias_d(&bias_md);
-
- const int regs = 28;
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- int ndims = src_d.ndims();
-
- jcp = zero<decltype(jcp)>();
- jcp.ndims = ndims;
- jcp.prop_kind = cd.prop_kind;
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
- jcp.iw = src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
- jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
- jcp.ow = dst_d.dims()[ndims-1];
- jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
- jcp.kw = weights_d.dims()[with_groups + ndims-1];
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
-
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
- jcp.back_pad = (jcp.od - 1) * jcp.stride_d
- + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
-
- jcp.is_1stconv = is_1stconv(jcp);
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1
- && src_d.data_type() == data_type::f32;
-
- const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
- jcp.simd_w = full_simd_w;
- bool ok_to_try_xmm = true
- && mayiuse(avx512_core)
- && src_d.data_type() == data_type::f32
- && !jcp.is_1stconv
- && !ok_to_pad_channels
- && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
- && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
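- // when the channel counts are not multiples of the full 16-lane simd width
- // (nor of 8) and padding them is not an option, avx512_core falls back to
- // 4-lane (Xmm-style) blocking.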
- if (ok_to_try_xmm)
- jcp.simd_w = 4;
-
- jcp.oc_block = jcp.simd_w;
- jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
- jcp.aligned_threads = 0;
-
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
- jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
- }
- bool args_ok = true
- && jcp.oc % jcp.oc_block == 0
- && jcp.ic % jcp.ic_block == 0;
- if (!args_ok)
- return status::unimplemented;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise) {
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
- if (dst_d.data_type() == data_type::s32) return status::unimplemented;
- }
-
- auto src_tag = jcp.is_1stconv
- ? pick(ndims - 3, ncw, nchw, ncdhw)
- : ((jcp.simd_w == 4)
- ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
- : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
- auto dst_tag = (jcp.simd_w == 4)
- ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
- : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = with_groups
- ? ((jcp.simd_w == 4)
- ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
- : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
- : ((jcp.simd_w == 4)
- ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
- : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
-
- if (src_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md, src_tag));
- jcp.src_tag = src_tag;
- } else {
- jcp.src_tag = src_d.matches_one_of_tag(src_tag);
- }
- if (jcp.src_tag != src_tag)
- return status::unimplemented;
-
- if (dst_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
- jcp.dst_tag = dst_tag;
- } else {
- jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag);
- }
- if (jcp.dst_tag != dst_tag)
- return status::unimplemented;
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
- if (jcp.with_bias) {
- if (bias_d.format_kind() == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md, x));
- }
-
- if (mayiuse(avx512_common) &&
- src_d.data_type() == data_type::f32
- && weights_d.data_type() == data_type::f32
- && dst_d.data_type() == data_type::f32) {
- jcp.ver = ver_fma;
- jcp.typesize_in = sizeof(float);
- jcp.typesize_out = sizeof(float);
- if (mayiuse(avx512_mic_4ops))
- jcp.ver = ver_4fma;
-
- if (jcp.is_1stconv) {
- // TODO: fix & remove constraints below
- bool not_for_4fma
- = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad),
- nstl::max(jcp.kw, jcp.kh) < 7);
- bool is_dilated
- = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
- if (one_of(true, not_for_4fma, is_dilated))
- jcp.ver = ver_fma;
- if (jcp.ver == ver_4fma) {
- wei_tag = with_groups
- ? ((jcp.simd_w == 4)
- ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
- : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
- : ((jcp.simd_w == 4)
- ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
- : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
- } else {
- wei_tag = with_groups
- ? ((jcp.simd_w == 4)
- ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
- : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
- : ((jcp.simd_w == 4)
- ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
- : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
- }
- }
- } else {
- return status::unimplemented;
- }
-
- if (weights_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
- jcp.wei_tag = wei_tag;
- } else {
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- }
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
-
- if (jcp.is_1stconv) {
- jcp.ur_w = nstl::min(jcp.ow, regs);
- } else {
- // avx512_core guard, just to avoid possible regressions on other archs
- if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
- jcp.ur_w = nstl::min(jcp.ow, regs);
- } else {
- for (int ur_w = regs; ur_w > 0; --ur_w) {
- if (jcp.ow % ur_w == 0) {
- jcp.ur_w = ur_w;
- break;
- }
- }
- }
- if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
- jcp.ur_w = nstl::min(jcp.ow, regs);
- }
- }
- // TODO (Tanya): currently applied to Segnet convolutions only.
- // Need to try for other topologies
- if (jcp.ow > 150 && jcp.ur_w < regs/2)
- jcp.ur_w = regs;
-
- int n_oi = (jcp.ow / jcp.ur_w);
- int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
- if (jcp.l_pad > 0 && r_pad > 0)
- n_oi--;
-
- bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
- && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
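- // when separate unrolled copies are needed for the left- and right-padded
- // borders, the generated kernel can get large; the block below shrinks ur_w
- // until the rough size estimate (about 9 bytes per op, num_ops_per_reg ops
- // per output register, one copy per padded variant) fits the 24 KB budget.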
- if (large_code_size) {
- const int max_code_size = 24 * 1024;
- const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
- int mult = 1;
- if (jcp.l_pad > 0) mult += 1;
- if (r_pad > 0) mult += 1;
- for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
- if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
- jcp.ur_w = ur_w;
- break;
- }
- }
- }
-
- /* Grouped channel offset to support 'non-blocked data' format for
- * convolution sizes with '(input_channel / ngroups) < simd' */
- jcp.nonblk_group_off
- = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ?
- jcp.ic :
- 1;
-
- jcp.nb_ic = jcp.ic / jcp.ic_block;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
-
- auto is_ow_threading_applicable = [=]() {
- return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
- && IMPLICATION(mayiuse(avx512_mic),
- jcp.ver == ver_4fma
- && IMPLICATION(jcp.mb != 1,
- jcp.ih == 1 && jcp.kh == 1)));
- };
-
- if (jcp.ver == ver_4fma && !jcp.is_1stconv) {
- if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
- && jcp.oh <= 8 && jcp.ow == jcp.oh)
- || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
- if (jcp.nb_oc % 2 == 0) {
- jcp.nb_oc_blocking = 2;
- jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
- }
- } else {
- for (int i = jcp.nb_oc; i > 0; i--)
- if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
- jcp.nb_oc_blocking = i;
- break;
- }
- }
- if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
- if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
- && jcp.ow != 2 * jcp.ur_w) {
- jcp.nb_oc_blocking = 2;
- jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
- }
- }
- }
-
- jcp.ow_block = jcp.ow;
-
- auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
- int nb_ow = div_up(jcp.ow, ow_block);
- int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
- int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
- float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
- float thr_eff = disbalance * (float)work_amount
- / rnd_up(work_amount, nthreads);
- return thr_eff;
- };
-
- auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
- int res_ow_block = jcp.ow;
- eff = get_thr_eff(nb_oc_blocking, res_ow_block);
- if (!is_ow_threading_applicable())
- return res_ow_block;
-
- int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
- if (jcp.ver == ver_4fma)
- L2_part /= 2;
- int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
- int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
- int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
- * jcp.kw * jcp.kh;
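- // nurw_cache below picks how many ur_w-wide steps fit in ~7/8 of L2 next to
- // two weight chunks, double-buffering the src and dst chunks.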
- int nurw_cache = (L2_part - 2 * size_wei_chunk)
- / (2 * size_dst_chunk + 2 * size_src_chunk);
- // current design of generate() requires ow_block >= 2 * ur_w
- int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
-
- int ow_block_thr = ow_block_cache;
- eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
-
- int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
- int start_nb_ow = div_up(jcp.ow, ow_block_thr);
- for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
- int ow_block
- = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
- float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
- if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
- break;
- if (div_up(jcp.ow, ow_block) != nb_ow)
- continue;
- float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
- float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
- if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
- ow_block_thr = ow_block;
- eff = thr_eff;
- }
- eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
- if (eff > eff_threshold)
- break;
- }
- res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
- eff = get_thr_eff(nb_oc_blocking, res_ow_block);
- return res_ow_block;
- };
-
-
- if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
- int try_nb_oc_blocking = 2;
- unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
- * jcp.ic_block * jcp.kh * jcp.kd;
- unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block
- * try_nb_oc_blocking;
- unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
- * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
- unsigned int ker_total_size = ker_inp_size + ker_out_size
- + ker_wei_size;
-
- bool embd_bcast_condition = true
- && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size)
- && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
- && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
-
- if (jcp.mb == 1) {
- unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
- * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
- unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
-
- // Estimate whether we need to limit the number of threads
- // and calculate this number. Includes some heuristics.
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
- int job_size_min = work_amount / nthreads;
- int job_size_max = div_up(work_amount, nthreads);
- int ch_max = rnd_up(jcp.oh, job_size_max);
- int ch_min = (job_size_min == 0)
- ? jcp.oh
- : rnd_up(jcp.oh, job_size_min);
- bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
- && (jcp.oh != 8 || ch_max / jcp.oh > 1);
- bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
- && (jcp.oh != 8 || ch_min / jcp.oh > 1);
- bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
- || nthreads > oc_chunks;
- if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
- && wei_size / inp_size > 24
- && (not_aligned_max || not_aligned_min)
- && eligible_case) {
- // Try to find nthreads > mkldnn_get_max_threads() / 2 such
- // that oc_chunks is a multiple of nthreads, or nthreads is a
- // multiple of oc_chunks. Otherwise, keep default value.
- // TODO: implement a task-based alternative without throttling.
- jcp.aligned_threads = nthreads;
- for (int i = nthreads; i > nthreads / 2; i--) {
- if (oc_chunks % i == 0 || i % oc_chunks == 0) {
- jcp.aligned_threads = i;
- break;
- }
- }
- }
- }
-
- if (jcp.kw > 3
- || (jcp.stride_w == 1 && jcp.stride_h == 1
- && embd_bcast_condition)
- || ((jcp.stride_w != 1 || jcp.stride_h != 1)
- && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
- && embd_bcast_condition)))
- || (jcp.mb == 1
- && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
- || (jcp.ow <= 147 && jcp.oc <= 96)))) {
- jcp.kernel_kind = embd_bcast;
- jcp.ur_w = nstl::min(jcp.ow, regs);
- jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
- if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
- && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
- && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
- && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
- jcp.nb_oc_blocking = try_nb_oc_blocking;
- jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
- }
- } else {
- jcp.kernel_kind = expl_bcast;
- jcp.nb_ic_blocking = 1;
- if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
- float best_thr_eff = 0.f;
- int best_nb_oc_blocking = 1;
- for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
- if (jcp.nb_oc % i == 0) {
- float thr_eff;
- int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
- get_ow_block(i, ur_w, thr_eff);
- if (thr_eff > 1.05f * best_thr_eff) {
- best_nb_oc_blocking = i;
- best_thr_eff = thr_eff;
- }
- }
- }
- jcp.nb_oc_blocking = best_nb_oc_blocking;
- jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
- }
- }
- }
-
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
-
- args_ok = true
- && jcp.l_pad <= jcp.ur_w
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
- if (!args_ok)
- return status::unimplemented;
-
- int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1));
- if (r_pad_no_tail > jcp.ur_w)
- return status::unimplemented;
-
- pick_loop_order(jcp);
-
- jcp.nb_ic_L2 = jcp.nb_ic;
-
- float thr_eff;
- jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
- jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
-
- const int L2_size = get_cache_size(2, true) / sizeof(float);
- // Source and output data need to fit in L2,
- // leaving some space for weights and prefetching.
- int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
- - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
- / (jcp.stride_h * jcp.iw + jcp.ow));
- jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
-
- if (jcp.ver == ver_4fma) {
- if (!is_ow_threading_on(jcp)) {
- for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
- divf++) {
- size_t l2_src
- = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
- size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
- * jcp.oh * jcp.od;
- size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
- * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
- if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
- if (jcp.kh == 3 && jcp.oh == 7) {
- jcp.nb_ic_L2 = 1;
- break;
- }
- temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf
- : jcp.nb_ic_L2);
- } else {
- jcp.nb_ic_L2 = temp_nb;
- break;
- }
- }
- } else if (jcp.ic > 64) {
- jcp.nb_ic_L2 = 2; /* according to performance data */
- }
- }
-
- return status::success;
-}
-
-void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
-{
- for (int k = 0; k < jcp.nb_ic_blocking; k++) {
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- vpxord(zmm, zmm, zmm);
- size_t aux_src_offset
- = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
- * jcp.ic_block;
- mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
- reg_long_offt));
- }
- }
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
-{
- Label no_update_label;
-
- mov(reg_channel, ptr[param + GET_OFF(channel)]);
- cmp(reg_channel, 0);
- je(no_update_label, T_NEAR);
- for (int k = 0; k < jcp.nb_ic_blocking; k++) {
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- size_t aux_src_offset = (size_t)typesize
- * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
- vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
- reg_long_offt));
- }
- }
-
- L(no_update_label);
- for (int k = 0; k < jcp.nb_ic_blocking; k++) {
- for (int j = 0; j < ur_w; j++) {
- Zmm zmm = zmm_out(j, k);
- size_t aux_src_offset = (size_t)typesize
- * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
- vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
- reg_long_offt), zmm);
- mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
- reg_long_offt));
- }
- }
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
- int ur_w, int l_overflow, int r_overflow)
-{
- int ow = jcp.ow;
- int kw = jcp.kw;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- Label kh_label, last_iter_label, loop_end_label, kd_label;
- int ker_load_number = 4;
- int shift_ker_ptr = typesize * kw * oc_block * ic_block;
- int shift_dst_ptr = typesize * ow * oc_block;
- int ii_dpref_t0 = get_iw_start(0, l_overflow);
- int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow);
-
- bool check_last_kh = (jcp.kh > 3);
- auto kernel_offset = [=](int icb, int oc, int ki) {
- int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
- int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
- int oc_offset = oc * jcp.oc_block;
- return typesize * (blk_offset + oc_offset);
- };
- auto kernel_loads = [=](int ki, int oc, int kk) {
- for (int ii = 0; ii < ker_load_number; ii++) {
- int aux_kernel_offset = kernel_offset(kk, oc + ii, ki);
- vmovups(zmm_ker(ii),
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- }
- };
- auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
- if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
- && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) {
- int aux_dst_offset = typesize * ((ii_dpref_t0
- + jcp.l_pad) * oc_block + jcp.ow * oc_block);
- prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
- ii_dpref_t0++;
- }
- };
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_dst, reg_dst);
- mov(aux_reg_ker, reg_ker);
- mov(aux_reg_dst_prf, reg_dst_prf);
- mov(aux_reg_ker_prf, reg_ker_prf);
- }
-
- if (jcp.ndims == 5) {
- push(reg_src_prf);
- push(reg_src);
-
- mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
- mov(aux_reg_dst_d, reg_dst);
- mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
- mov(aux_reg_dst_d_prf, reg_dst_prf);
- mov(aux_reg_ker_d_prf, reg_ker_prf);
-
- L(kd_label);
- mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_dst, aux_reg_dst_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
- mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
- }
-
- align(16);
- L(kh_label);
- if (check_last_kh) {
- for (int ki = 0; ki < kw; ki++)
- for (int oc = 0; oc < oc_block; oc += 4)
- for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
- bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1
- && ki == kw - 1 && (oc + 4) == oc_block);
-
- if (last_kernel_loads) {
- cmp(reg_kj, 1);
- je(last_iter_label, T_NEAR);
- }
-
- kernel_loads(ki, oc, kk);
- for (int ii = get_iw_start(ki, l_overflow),
- prf_count_t0 = 0, prf_count_t1 = 0;
- ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
- int aux_dst_offset = typesize
- * ((ii + jcp.l_pad - ki) * oc_block + oc);
- v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
- EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
-
- if (ii % 2) {
- if (prf_count_t0 < 4) {
- int aux_kernel_prf;
- if (last_kernel_loads)
- aux_kernel_prf = kernel_offset(0, prf_count_t0
- + oc + 4 - oc_block, 0) + typesize * kw
- * oc_block * ic_block;
- else
- aux_kernel_prf = kernel_offset(kk, oc + 4
- + prf_count_t0, ki);
- mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
- aux_kernel_prf));
- prf_count_t0++;
- } else if (prf_count_t1 < 4) {
- mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
- kernel_offset(kk, oc + prf_count_t1, ki)));
- prf_count_t1++;
- }
- } else
- prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1);
- }
- if (last_kernel_loads) {
- jmp(loop_end_label, T_NEAR);
-
- L(last_iter_label);
-
- kernel_loads(ki, oc, kk);
- for (int ii = get_iw_start(ki, l_overflow),
- prf_count_t0 = 0, prf_count_t1 = 0;
- ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
- int aux_dst_offset = typesize
- * ((ii + jcp.l_pad - ki) * oc_block + oc);
- v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
- EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
- if (ii % 2) {
- if (prf_count_t0 < 4) {
- mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
- kernel_offset(0, prf_count_t0, 0)));
- prf_count_t0++;
- } else if (prf_count_t1 < 4) {
- mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
- kernel_offset(kk, oc + prf_count_t1, ki)));
- prf_count_t1++;
- }
- }
- }
- L(loop_end_label);
- }
- }
- } else {
- for (int ki = 0; ki < kw; ki++)
- for (int oc = 0; oc < oc_block; oc += 4)
- for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
- kernel_loads(ki, oc, kk);
-
- for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0;
- ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
- int aux_dst_offset = typesize
- * ((ii + jcp.l_pad - ki) * oc_block + oc);
- v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
- EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
- if ((ii % 2) && (prf_count_t1 < 4)) {
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_ker_prf, kernel_offset(kk,
- oc + prf_count_t1, ki)));
- prf_count_t1++;
- }
- if (ki == 1 && oc == 0 && kk == 0)
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_dst_prf, aux_dst_offset));
- }
- }
- }
-
- add(aux_reg_ker, shift_ker_ptr);
- sub(aux_reg_dst, shift_dst_ptr);
- add(aux_reg_ker_prf, shift_ker_ptr);
- sub(aux_reg_dst_prf, shift_dst_ptr);
-
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
-
- if (jcp.ndims == 5) {
- sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
- sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block);
- add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_src);
- pop(reg_src_prf);
- }
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
- int ur_w, int l_overflow, int r_overflow)
-{
- Label kh_label, kd_label;
- int kw = jcp.kw;
- int ow = jcp.ow;
-
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int l_pad = jcp.l_pad;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
- int stride_h = jcp.stride_h;
-
- int ker_pipeline_depth = 4;
- assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
- assert(oc_block >= ker_pipeline_depth);
-
- int num_ker_loads = oc_block * kw;
- int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
- + nstl::max(0, kw - stride_w);
- int num_prfs = num_ker_loads + num_inp_prfs;
- int num_fmas = num_ker_loads * ur_w / stride_w;
- int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
- int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_dst, reg_dst);
- mov(aux_reg_ker, reg_ker);
-
- mov(aux_reg_dst_prf, reg_dst_prf);
- mov(aux_reg_ker_prf, reg_ker_prf);
- }
-
- if (jcp.ndims == 5) {
- push(reg_src_prf);
- push(reg_src);
-
- mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
- mov(aux_reg_dst_d, reg_dst);
- mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
- mov(aux_reg_dst_d_prf, reg_dst_prf);
- mov(aux_reg_ker_d_prf, reg_ker_prf);
-
- L(kd_label);
- mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_dst, aux_reg_dst_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
- mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
- }
-
- L(kh_label); {
- int step = 0;
- int ker_prfs = 0;
- for (int ki = 0; ki < kw; ki++) {
- for (int oc = 0; oc < oc_block; oc++) {
- if (step == 0) {
- for (int i = 0; i < ker_pipeline_depth; i++) {
- int aux_kernel_offset = typesize * ((oc + i) * oc_block
- + ki * ic_block * oc_block);
- vmovups(zmm_ker(i), EVEX_compress_addr(
- aux_reg_ker, aux_kernel_offset));
- }
- } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
- int load_offset = ker_pipeline_depth - 1;
- int ker_load_reg_idx
- = (step + load_offset) % ker_pipeline_depth;
- int aux_kernel_offset = typesize * ((oc + load_offset)
- * oc_block + ki * ic_block * oc_block);
- vmovups(zmm_ker(ker_load_reg_idx),
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- }
-
- bool ker_prf_inserted = false;
- auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);
-
- int jj_start = get_iw_start(ki, l_overflow);
- int jj_end = get_iw_end(ur_w, ki, r_overflow);
- assert(stride_w != 1
- || jj_start == nstl::max(0,
- l_overflow - (kw - 1 - ki) * dilate_w));
- assert(stride_w != 1
- || jj_end == ur_w - nstl::max(0,
- r_overflow - ki * dilate_w));
-
- for (int jj = jj_start; jj < jj_end; jj += stride_w) {
- assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
- int aux_dst_offset = typesize *
- (((jj + l_pad - ki * dilate_w)
- / stride_w) * jcp.oc_block + oc);
- vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
- EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));
-
- int fma_idx = (step * ur_w + jj) / stride_w;
- int prf_slot_idx = fma_idx / prf_inst_spacing;
- if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
- if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
- int ker_prf_offset = typesize
- * ker_prfs * jcp.oc_block;
- mic_prefetcht1(EVEX_compress_addr(
- aux_reg_ker_prf, ker_prf_offset));
- ker_prf_inserted = true;
- ker_prfs++;
- } else {
- int inp_prf_idx = prf_slot_idx - ker_prfs;
- if (inp_prf_idx < num_inp_prfs) {
- int inp_prf_offset
- = ic_block * typesize
- * ((inp_prf_idx / kw) * kw
- + (inp_prf_idx % kw));
- mic_prefetcht0(EVEX_compress_addr(
- aux_reg_dst_prf, inp_prf_offset));
- }
- }
- }
- }
- step++;
- }
- }
-
- add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
- sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
- add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
- sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);
-
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
- }
- if (jcp.ndims == 5) {
- sub(aux_reg_dst_d,
- typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
- add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
- * oc_block * ic_block);
- sub(aux_reg_dst_d_prf,
- typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
- add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
- * oc_block * ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
- }
-
- if (jcp.ndims == 5)
- {
- pop(reg_src);
- pop(reg_src_prf);
- }
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
- int ur_w, int l_overflow, int r_overflow)
-{
- int kw = jcp.kw;
- int ow = jcp.ow;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int nb_ic_block = jcp.nb_ic_blocking;
- Label kh_label, kd_label;
-
- int shift_ker_ptr = typesize * kw * oc_block * ic_block;
- int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
-
- auto output_offset = [=](int oi, int oc, int ki) {
- return typesize *
- (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc);
- };
- auto kernel_offset = [=](int icb, int oc, int ki) {
- int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
- int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
- int oc_offset = oc * jcp.oc_block;
- return typesize * (blk_offset + oc_offset);
- };
-
- if (one_of(jcp.ndims, 3, 4)) {
- mov(aux_reg_dst, reg_dst);
- mov(aux_reg_ker, reg_ker);
- }
-
- if (jcp.ndims == 5) {
- push(reg_src_prf);
- push(reg_src);
-
- mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
- mov(aux_reg_dst_d, reg_dst);
- mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
-
- L(kd_label);
- mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
- } else {
- mov(reg_kj, reg_kh);
- }
-
- if (jcp.ndims == 5) {
- mov(aux_reg_dst, aux_reg_dst_d);
- mov(aux_reg_ker, aux_reg_ker_d);
- }
-
- L(kh_label);
- {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = get_iw_start(ki, l_overflow);
- int jj_end = get_iw_end(ur_w, ki, r_overflow);
- for (int oc = 0; oc < oc_block; oc++) {
- if (jcp.kernel_kind == expl_bcast) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_output_offset = output_offset(jj, oc, ki);
- vbroadcastss(zmm_inp(jj, nb_ic_block),
- ptr[aux_reg_dst + aux_output_offset]);
- }
- }
- for (int ii = 0; ii < nb_ic_block; ii++) {
- int aux_kernel_offset = kernel_offset(ii, oc, ki);
- if (jj_end - jj_start > 0)
- vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
- aux_kernel_offset));
- for (int jj = jj_start; jj < jj_end; jj += stride_w)
- if (jcp.kernel_kind == expl_bcast)
- vfmadd231ps(zmm_out(jj, ii),
- zmm_inp(jj, nb_ic_block), zmm_wei);
- else
- vfmadd231ps(zmm_out(jj, ii), zmm_wei,
- EVEX_compress_addr(aux_reg_dst,
- output_offset(jj, oc, ki), true));
- }
- }
- }
- add(aux_reg_ker, shift_ker_ptr);
- sub(aux_reg_dst, shift_dst_ptr);
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
- }
-
- if (jcp.ndims == 5) {
- sub(aux_reg_dst_d,
- typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
- add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
-
- dec(reg_ki);
- cmp(reg_ki, 0);
- jg(kd_label, T_NEAR);
-
- pop(reg_src);
- pop(reg_src_prf);
- }
-}
-
-inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
- int ur_w, int l_overflow, int r_overflow)
-{
- if (jcp.ndims == 5) push(reg_oi);
-
- prepare_output(ur_w);
-
- Label skip_compute_loop;
- if (jcp.ndims == 5) {
- mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
- cmp(reg_kj, 0);
- je(skip_compute_loop, T_NEAR);
- }
- mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
- cmp(reg_kj, 0);
- je(skip_compute_loop, T_NEAR);
-
- if (jcp.ver == ver_4fma)
- compute_loop_4fma(ur_w, l_overflow, r_overflow);
- else if (jcp.ver == ver_fma)
- if (mayiuse(avx512_mic))
- compute_loop_fma(ur_w, l_overflow, r_overflow);
- else
- if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
- compute_loop_fma(ur_w, l_overflow, r_overflow);
- else
- compute_loop_fma_core(ur_w, l_overflow, r_overflow);
- else
- assert("!unknown convolution version");
-
- L(skip_compute_loop);
- store_output(ur_w);
- if (jcp.ndims == 5) pop(reg_oi);
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
-{
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ur_w = jcp.ur_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int ur_w_tail = jcp.ur_w_tail;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
-
- int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
- int src_shift = jcp.typesize_out * ur_w * oc_block;
-
- preamble();
-
- mov(reg_src, ptr[param + GET_OFF(src)]);
- mov(reg_dst, ptr[param + GET_OFF(dst)]);
- mov(reg_ker, ptr[param + GET_OFF(filt)]);
-
- mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
- mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]);
- mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]);
- mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]);
-
- int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
- int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
- - nstl::max(0, jcp.r_pad)) / stride_w);
- int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
- - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
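- // roughly, l_overflow / r_overflow count the border columns whose kernel
- // taps fall outside the available diff_dst row; r_overflow1 is the same
- // measure taken for the last full ur_w block before the ur_w_tail.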
-
- int n_oi = iw / ur_w;
- if (r_overflow1 > 0) n_oi--;
-
- if (ur_w == iw) {
- compute_loop(ur_w, l_overflow, r_overflow);
- } else if (n_oi == 0) {
- compute_loop(ur_w, l_overflow, r_overflow1);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- add(reg_src_prf, src_shift);
- add(reg_dst_prf, dst_shift);
- if (ur_w_tail != 0)
- compute_loop(ur_w_tail, 0, r_overflow);
- } else {
- xor_(reg_oi, reg_oi);
- if (l_overflow > 0) {
- compute_loop(ur_w, l_overflow, 0);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- add(reg_src_prf, src_shift);
- add(reg_dst_prf, dst_shift);
-
- inc(reg_oi);
- }
- if ((l_overflow <= 0 && n_oi > 0)
- || (l_overflow > 0 && n_oi > 1)) {
- Label ow_loop_label;
- L(ow_loop_label); {
- compute_loop(ur_w, 0, 0);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- add(reg_src_prf, src_shift);
- add(reg_dst_prf, dst_shift);
-
- inc(reg_oi);
- cmp(reg_oi, n_oi);
- jl(ow_loop_label, T_NEAR);
- }
- }
- if (r_overflow1 > 0) {
- compute_loop(ur_w, 0, r_overflow1);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- add(reg_src_prf, src_shift);
- add(reg_dst_prf, dst_shift);
- }
- if (ur_w_tail != 0) {
- compute_loop(ur_w_tail, 0, r_overflow);
- }
- }
-
- postamble();
-}
-
-status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
- jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d)
-{
- if (!mayiuse(avx512_common)) return status::unimplemented;
-
- jcp = zero<decltype(jcp)>();
-
- jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
- const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
- int ndims = diff_src_d.ndims();
-
- jcp.ndims = ndims;
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = diff_src_d.dims()[0];
-
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
-
- jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
- jcp.iw = diff_src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
- jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
- jcp.ow = diff_dst_d.dims()[ndims-1];
-
- jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
-
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
- if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
- || (jcp.dilate_d != 0 && jcp.stride_d != 1)
- || (jcp.dilate_h != 0 && jcp.stride_h != 1))
- return status::unimplemented;
-
- jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1);
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
- jcp.back_pad = (jcp.od - 1) * jcp.stride_d
- + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
-
- jcp.aligned_threads = 0;
-
- jcp.is_1stconv = false;
-
- jcp.oc_block = jcp.simd_w;
- jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1
- && diff_src_d.data_type() == data_type::f32;
-
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
- jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
- }
-
- auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = with_groups
- ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
- : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
- jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.oc % jcp.oc_block == 0
- && jcp.ic % jcp.ic_block == 0
- && jcp.src_tag == dat_tag
- && jcp.dst_tag == dat_tag;
- if (!args_ok)
- return status::unimplemented;
-
- jcp.nb_ic = jcp.ic / jcp.ic_block;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- jcp.ur_w = jcp.stride_w;
-
- int regs = 28;
- if (jcp.iw <= regs)
- jcp.ur_w = jcp.iw;
- else {
- for (int ur_w = regs; ur_w > 0; --ur_w)
- if (ur_w % jcp.stride_w == 0) {
- jcp.ur_w = ur_w;
- break;
- }
- }
- int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
- - jcp.l_pad) / jcp.stride_w);
- int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
- - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
- int n_oi = jcp.iw / jcp.ur_w;
- if (r_overflow1 > 0) n_oi--;
-
- if (mayiuse(avx512_common)
- && diff_dst_d.data_type() == data_type::f32
- && weights_d.data_type() == data_type::f32
- && diff_src_d.data_type() == data_type::f32) {
- jcp.ver = ver_fma;
- jcp.typesize_in = sizeof(float);
- jcp.typesize_out = sizeof(float);
- if (mayiuse(avx512_mic_4ops)
- && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
- jcp.ver = ver_4fma;
- }
- } else {
- return status::unimplemented;
- }
-
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
-
- if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
- && jcp.ver != ver_fma)
- return status::unimplemented;
-
- jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
- if (jcp.ver == ver_4fma) {
- if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
- jcp.nb_ic_blocking = 2;
- } else {
- for (int i = jcp.nb_ic; i > 0; i--)
- if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
- jcp.nb_ic_blocking = i;
- break;
- }
- }
- }
-
- jcp.loop_order = loop_gnc;
-
- bool large_code_size = (jcp.ur_w != jcp.ow)
- && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1))
- && (r_overflow1 > 0) && (l_overflow > 0);
- if (large_code_size) {
- const int max_code_size = 24 * 1024;
- const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
- int mult = 1;
- if (l_overflow > 0) mult += 1;
- if (r_overflow1 > 0) mult += 1;
- for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
- if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
- < max_code_size) {
- if (ur_w % jcp.stride_w == 0) {
- jcp.ur_w = ur_w;
- break;
- }
- }
- }
- }
-
- if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
- int try_nb_ic_blocking = 2;
- unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
- * try_nb_ic_blocking * jcp.kh;
- unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
- unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
- * jcp.oc_block * try_nb_ic_blocking;
- unsigned int ker_total_size = ker_inp_size + ker_out_size
- + ker_wei_size;
- if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
- || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
- || ker_total_size > L1_cache_size )))
- || jcp.stride_h > 1 || jcp.stride_d > 1) {
- jcp.kernel_kind = embd_bcast;
- jcp.ur_w = nstl::min(jcp.iw, regs);
- jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
- if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
- && jcp.ow > 8)) && jcp.stride_h == 1)
- if (jcp.nb_ic % try_nb_ic_blocking == 0) {
- jcp.nb_ic_blocking = try_nb_ic_blocking;
- jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
- if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
- }
- } else {
- jcp.kernel_kind = expl_bcast;
- jcp.nb_oc_blocking = 1;
- jcp.nb_ic_blocking = 4;
- if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
- if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
- for (int i = jcp.nb_ic_blocking; i > 0; i--)
- if (jcp.nb_ic % i == 0) {
- jcp.nb_ic_blocking = i;
- break;
- }
- jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
- if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
- }
- }
- jcp.ur_w_tail = jcp.iw % jcp.ur_w;
-
- if (l_overflow * jcp.stride_w > jcp.ur_w)
- return status::unimplemented;
- int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
- - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
- if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
- return status::unimplemented;
- if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
- return status::unimplemented;
-
- pick_loop_order(jcp);
-
- jcp.nb_oc_L2 = jcp.nb_oc;
- if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
- for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
- divf++) {
- size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
- * jcp.id;
- size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
- size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
- * jcp.kd * jcp.nb_ic_blocking * temp_nb;
- if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
- if (jcp.kh == 3 && jcp.ih == 7) {
- jcp.nb_oc_L2 = 1;
- break;
- }
- temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
- : jcp.nb_oc_L2);
- } else {
- jcp.nb_oc_L2 = temp_nb;
- break;
- }
- }
- }
-
- args_ok = true
- && jcp.ic <= diff_src_d.padded_dims()[1]
- && jcp.oc <= diff_dst_d.padded_dims()[1]
- && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
- if (!args_ok) return status::unimplemented;
-
- return status::success;
-}
-
-void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- UNUSED(scratchpad);
- UNUSED(jcp);
-}
-
-const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
-{
- Label kd_comeback_label;
-
- /* 'depth' loop count bound by 'kd_work_size' */
- mov(kj, reg_kd_count);
- L(kd_comeback_label); {
- int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
- sub(reg_input,
- jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
- sub(reg_kernel,
- jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kd_comeback_label, T_NEAR);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
-{
- Label kh_comeback_label, kd_comeback_label;
- mov(kj, reg_kh);
- L(kh_comeback_label); {
- int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
- sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
- sub(reg_kernel,
- jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_comeback_label, T_NEAR);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
- int ur_w, int pad_l, int pad_r,
- int ic_block_step, int input_offset, int kernel_offset,
- int output_offset, bool input_wraparound)
-{
-
- int kw = jcp.kw;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
- vmovups(Zmm(i_kw * ic_block_step + i_ic),
- EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block
- + i_ic) * jcp.oc_block + kernel_offset));
-
- for (int i_ur = 0; i_ur < ur_w; i_ur++) {
- if (i_ur == 0) {
- vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4),
- EVEX_compress_addr(reg_output, typesize * (i_ur + 0)
- * oc_block + output_offset));
- if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4),
- EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block
- + output_offset));
- if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4),
- EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block
- + output_offset));
- if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
- EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
- + output_offset));
- } else if (i_ur + 3 < ur_w)
- vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
- EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
- + output_offset));
-
- for (int i_kw = 0; i_kw < kw; i_kw++) {
- int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
- if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
- (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- const size_t i_offset = (size_t)input_offset
- + (size_t)typesize * (jcp.ver == ver_4fma
- ? (i_iw - pad_l + i_ic * jcp.tr_iw)
- : (jcp.is_1stconv
- ? (i_iw - pad_l) + (size_t)i_ic
- * ((size_t)jcp.ih*jcp.iw*jcp.id)
- : (i_iw - pad_l) * ic_block + i_ic));
- vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
- Zmm(kw * ic_block_step + i_ur % 4),
- EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
- true));
- }
- }
- }
-
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
- vmovups(EVEX_compress_addr(reg_kernel, typesize
- * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset),
- Zmm(i_kw * ic_block_step + i_ic));
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma(
- int ur_w, int pad_l, int pad_r,
- int ic_block_step, int input_offset, int kernel_offset,
- int output_offset, bool input_wraparound)
-{
- // TODO: add prefetches to fma version as well
-
- assert(jcp.ver == ver_4fma);
-
- int kw = jcp.kw;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
-
- auto zmm_ker = [=](int i_kw, int i_ic) {
- return Zmm(i_kw * ic_block_step + i_ic);
- };
-
- auto ker_addr = [=](int i_kw, int i_ic) {
- size_t local_offset
- = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
- return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
- };
-
- auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) {
- int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
- int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
- return EVEX_compress_addr(reg_input,
- local_offset + input_offset + extra_offset);
- };
-
- auto zmm_out = [=](int i_iw) {
- // TODO: move reg calc to global member funcs
- const int out_zmm_base_idx = 28;
- return Zmm(out_zmm_base_idx + i_iw % 4);
- };
-
- auto out_addr = [=](int i_ur) {
- return EVEX_compress_addr(reg_output,
- jcp.typesize_in * i_ur * oc_block + output_offset);
- };
-
- auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
- assert(i_ur % 4 == 0);
- if (i_ur == 0)
- prefetcht1(ker_addr(i_kw, i_ic));
- if (i_ur + 4 >= ur_w)
- prefetcht0(ker_addr(i_kw, i_ic));
-
- const ptrdiff_t next_input_block_offset
- = jcp.typesize_in * ic_block_step * jcp.tr_iw;
- if (i_ur % 16 == 4 && i_kw == 0) {
- if (i_ur + 16 < ur_w)
- prefetcht0(inp_addr(i_ur + 16, i_ic));
- else
- prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
- }
- if (i_ur % 16 == 4 && i_kw == 1) {
- if (input_wraparound)
- prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
- else
- prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
- }
- };
-
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- auto zmm = zmm_ker(i_kw, i_ic);
- vpxord(zmm, zmm, zmm);
- }
-
- for (int i_ur = 0; i_ur < ur_w; i_ur += 4) {
-
- for (int i = 0; i < 4; i++) {
- auto zmm = zmm_out(i_ur + i);
- if (i_ur + i < ur_w)
- vmovups(zmm, out_addr(i_ur + i));
- else
- vpxord(zmm, zmm, zmm);
- prefetcht0(out_addr(i_ur + i + 4));
- }
-
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- int i_iw = i_ur + i_kw;
- v4fmaddps(zmm_ker(i_kw, i_ic),
- zmm_out(i_ur), inp_addr(i_iw, i_ic));
- pf_callback(i_ur, i_kw, i_ic);
- }
- }
-
- for (int i_kw = 0; i_kw < kw; i_kw++)
- for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
- auto addr = ker_addr(i_kw, i_ic);
- auto zmm = zmm_ker(i_kw, i_ic);
- vaddps(zmm, zmm, addr);
- vmovups(addr, zmm);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
- int ur_w, int pad_l, int pad_r,
- int ic_block_step, int input_offset, int kernel_offset,
- int output_offset, bool input_wraparound)
-{
- if (jcp.ver == ver_4fma)
- compute_ic_block_step_4fma(ur_w, pad_l, pad_r,
- ic_block_step, input_offset, kernel_offset, output_offset,
- input_wraparound);
- else if (jcp.ver == ver_fma)
- compute_ic_block_step_fma(ur_w, pad_l, pad_r,
- ic_block_step, input_offset, kernel_offset, output_offset,
- input_wraparound);
- else
- assert(!"unknown convolution version");
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32
- ::compute_oh_step_unroll_ow_icblock(
- int ic_block_step, int max_ur_w)
-{
- UNUSED(max_ur_w);
-
- Label kh_label, kd_label;
-
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int inp_mul = !jcp.is_1stconv ? ic_block : 1;
- int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
- int ow = jcp.ow;
-
- int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- int l_pad = jcp.l_pad;
-
- if (jcp.ndims == 5) {
- L(kd_label);
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- }
-
- mov(kj, reg_kh);
- L(kh_label);
- {
- for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
- const int input_offset = jcp.typesize_in
- * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic);
- compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
- input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
- i_b_ic + ic_block_step >= jcp.ic_block);
- }
- add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
- add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_label, T_NEAR);
- }
-
- if (jcp.ndims == 5) {
- add(aux_reg_input,
- jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
- add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
- * oc_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32
- ::compute_oh_step_unroll_ow(
- int ic_block_step, int max_ur_w)
-{
- Label kh_label, ic_block_label, kd_label;
-
- UNUSED(max_ur_w);
-
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
-
- int ow = jcp.ow;
-
- int r_pad = nstl::max(0,
- (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1));
- int l_pad = jcp.l_pad;
-
- if (jcp.ndims == 5) {
- L(kd_label);
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- }
-
- mov(kj, reg_kh);
- L(kh_label);
- {
- xor_(b_ic, b_ic);
- L(ic_block_label); {
- compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
- 0, 0, 0);
- size_t inp_icblk_stride = jcp.is_1stconv
- ? (size_t)jcp.ih * jcp.iw * jcp.id
- : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
- size_t input_offset
- = inp_icblk_stride * jcp.typesize_in * ic_block_step;
- safe_add(reg_input, input_offset, reg_long_offt);
- add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
- add(b_ic, ic_block_step);
- cmp(b_ic, jcp.ic_block);
- jl(ic_block_label, T_NEAR);
- }
-
- if (jcp.is_1stconv) {
- size_t input_offset
- = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
- safe_sub(reg_input, input_offset, reg_long_offt);
- add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
- } else if (jcp.ver != ver_4fma) {
- add(reg_input, jcp.typesize_in
- * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
- }
- add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_label, T_NEAR);
- }
- if (jcp.ndims == 5) {
- add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
- * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
- add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
- * oc_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32
- ::compute_oh_step_common(
- int ic_block_step, int max_ur_w)
-{
- Label kh_label, ic_block_label, ow_block_label, kd_label;
-
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
-
- int ow = jcp.ow;
- int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad;
-
- int ur_w = nstl::min(ow, max_ur_w);
- int ur_w_trips = ow / ur_w;
- int ur_w_tail = ow % ur_w;
- if ((ur_w_tail == 0 && r_pad != 0)
- || r_pad >= ur_w_tail) {
- if (ur_w_trips > 1) {
- ur_w_tail += ur_w;
- ur_w_trips--;
- } else {
- ur_w_tail += (ur_w - ur_w / 2);
- ur_w = ur_w / 2;
- }
- }
-
- int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block;
- int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult;
- int output_comeback = ur_w_trips * ur_w * oc_block;
-
- if (jcp.ndims == 5) {
- L(kd_label);
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- }
-
- mov(kj, reg_kh);
- L(kh_label); {
- xor_(b_ic, b_ic);
- L(ic_block_label); {
- if (l_pad != 0) {
- ur_w_trips--;
- compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
- add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad)
- * inp_mult);
- add(reg_output, jcp.typesize_in * ur_w * oc_block);
- }
-
- if (ur_w_trips > 0) {
- xor_(reg_ur_w_trips, reg_ur_w_trips);
- L(ow_block_label); {
- compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
- add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w
- * inp_mult);
- add(reg_output, jcp.typesize_in * ur_w * oc_block);
-
- inc(reg_ur_w_trips);
- cmp(reg_ur_w_trips, ur_w_trips);
- jl(ow_block_label, T_NEAR);
- }
- }
-
- if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad,
- ic_block_step, 0, 0, 0);
-
- sub(reg_input, jcp.typesize_in * input_comeback);
- sub(reg_output, jcp.typesize_in * output_comeback);
- int inp_icblk_stride = jcp.is_1stconv
- ? jcp.ih * jcp.iw * jcp.id
- : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
- size_t input_offset
- = inp_icblk_stride * jcp.typesize_in * ic_block_step;
- safe_add(reg_input, input_offset, reg_long_offt);
- add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
-
- add(b_ic, ic_block_step);
- cmp(b_ic, jcp.ic_block);
- jl(ic_block_label, T_NEAR);
- }
- if (jcp.is_1stconv) {
- size_t input_offset
- = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
- safe_sub(reg_input, input_offset, reg_long_offt);
- add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
- } else if (jcp.ver != ver_4fma) {
- add(reg_input, jcp.typesize_in
- * ((jcp.dilate_h + 1 ) * jcp.iw - 1) * ic_block);
- }
- add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
- dec(kj);
- cmp(kj, 0);
- jg(kh_label, T_NEAR);
- }
- if (jcp.ndims == 5) {
- add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
- * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
- add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
- * oc_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32
- ::compute_oh_step_disp()
-{
- int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
- if (jcp.is_1stconv) {
- bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
- ic_block_step
- = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1;
- }
-
- bool too_large_to_unroll
- = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
- && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
-
- int ow = jcp.ow;
- if (jcp.ndims == 5) {
- /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of
- * 'movs' must be guaranteed. */
- mov(ki, reg_kd_count);
- push(reg_kd_count);
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
- }
-
- if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
- compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
- else if (ow <= max_ur_w)
- compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
- else
- compute_oh_step_common(ic_block_step, max_ur_w);
-
- if (jcp.ndims == 5) {
- mov(reg_input, aux_reg_input);
- mov(reg_kernel, aux_reg_kernel);
- pop(reg_kd_count);
- od_step_comeback_pointers();
- } else {
- oh_step_comeback_pointers();
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
-{
- Label skip_zeroing, zeroing_loop;
-
- mov(reg_tmp, ptr[param + GET_OFF(channel)]);
- cmp(reg_tmp, 0);
- jz(skip_zeroing, T_NEAR);
-
- Zmm zero = Zmm(0);
- vpxord(zero, zero, zero);
- xor_(reg_tmp, reg_tmp);
- L(zeroing_loop); {
- assert(jcp.oc_block * jcp.typesize_out
- == cpu_isa_traits<avx512_common>::vlen);
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
- vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
- * jcp.typesize_out], zero);
- add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
- cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd
- * jcp.typesize_out);
- jnz(zeroing_loop);
- }
-
- L(skip_zeroing);
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
-{
- Label skip_bias, bias_loop, skip_load_bias;
-
- mov(reg_tmp, ptr[param + GET_OFF(flags)]);
- test(reg_tmp,reg_tmp);
- jne(skip_bias, T_NEAR);
-
- mov(reg_bias, ptr[param + GET_OFF(bias)]);
- mov(reg_output, ptr[param + GET_OFF(dst)]);
- vpxord(Zmm(1), Zmm(1), Zmm(1));
-
- mov(reg_tmp, ptr[param + GET_OFF(channel)]);
- cmp(reg_tmp, 0);
- jne(skip_load_bias, T_NEAR);
- vmovups(Zmm(1), ptr[reg_bias]);
-
- L(skip_load_bias);
-
- mov(reg_oi, ptr[param + GET_OFF(d_worksize)]);
- sub(reg_oi, ptr[param + GET_OFF(d_index)]);
- mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out);
- imul(reg_oi, reg_tmp);
-
- xor_(reg_tmp, reg_tmp);
- L(bias_loop); {
- vmovups(Zmm(0), ptr[reg_output + reg_tmp]);
- vaddps(Zmm(1), Zmm(1), Zmm(0));
- add(reg_tmp, jcp.oc_block * jcp.typesize_out);
- cmp(reg_tmp, reg_oi);
- jl(bias_loop);
- }
- vmovups(EVEX_compress_addr(reg_bias,0), Zmm(1));
-
- L(skip_bias);
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32
- ::compute_oh_loop_common()
-{
- int b_pad = jcp.b_pad;
- int t_pad = jcp.t_pad;
- bool is_dilated = jcp.dilate_h != 0;
- int dilate_h = jcp.dilate_h + 1;
- int stride_h = jcp.stride_h;
- const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
- Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
- oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
- oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end;
-
- int ow = jcp.ow;
-
- mov(reg_kh, jcp.kh);
- xor_(reg_ih_count, reg_ih_count);
- xor_(reg_oj, reg_oj);
- /* Compute 'top' edge */
- if (t_pad > 0) {
- const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
- const int overflow
- = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
- const int underflow = div_up(t_pad, dilate_h);
- const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
- mov(reg_kh, initial_inp_ker_overlap);
- add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
- * jcp.oc_block);
- // generate loop to process kernel while it remains within t_pad + ih
- if (kh_range < t_pad + jcp.ih) {
- if (is_dilated) {
- const int tail = t_pad % dilate_h;
- const int shift = tail == 0 ? 0 : dilate_h - tail;
- mov(reg_tmp, shift);
- if (tail != 0)
- add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
- }
- L(oh_tpad_label); {
- compute_oh_step_disp();
- add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
- if (is_dilated) {
- inc(reg_tmp);
- cmp(reg_tmp, dilate_h);
- jl(oh_dilate_label_shift, T_NEAR);
- // unshift input as new kernel element enters
- sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
- xor_(reg_tmp, reg_tmp);
- }
- // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
- sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
- * jcp.ic_block * jcp.oc_block);
- add(reg_kh, stride_h);
- if (is_dilated) {
- jmp(oh_dilate_label_noshift, T_NEAR);
- L(oh_dilate_label_shift);
- // shift input as old kernel element progresses
- add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
- L(oh_dilate_label_noshift);
- }
- inc(reg_oj);
- add(reg_ih_count, stride_h);
-
- // final number of kernel elements that overlap with input
- const int final_inp_ker_overlap
- = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
- cmp(reg_kh, final_inp_ker_overlap);
- jl(oh_tpad_label, T_NEAR);
- }
- }
- // need second loop to process kernel if it is larger than the input
- // (does not apply to dilations as they must have unit stride)
- if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
- t_pad % stride_h)) {
- assert(!is_dilated);
- mov(reg_kh, jcp.ih);
- L(oh_tpad_tail_label); {
- compute_oh_step_disp();
- add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
- sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
- * jcp.ic_block * jcp.oc_block);
-
- inc(reg_oj);
- add(reg_ih_count, stride_h);
-
- cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
- jl(oh_tpad_tail_label, T_NEAR);
- }
- }
- // correct any excess shifts to kernel and input
- // (does not apply to dilations as they must have unit stride,
- // kernel must fit inside input, and padding is smaller than input)
- if (t_pad <= jcp.oh * stride_h) {
- // kernel has moved beyond padding (adjust for stride effects)
- if (t_pad % stride_h != 0) {
- assert(!is_dilated);
- int inp_corr = stride_h - t_pad % stride_h;
- add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
- * jcp.ic_block * jcp.oc_block);
- add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
- }
- } else {
- // kernel still overlaps padding (complete reset)
- assert(!is_dilated);
- sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
- * jcp.kw * jcp.ic_block * jcp.oc_block);
- }
- }
-
- cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
- jge(oh_label_end, T_NEAR);
- cmp(reg_oj, jcp.oh);
- jge(oh_label, T_NEAR);
-
- /* Compute middle block(s) */
- mov(reg_kh, jcp.kh);
- L(oh_label); {
- compute_oh_step_disp();
- add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
- add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
-
- inc(reg_oj);
- add(reg_ih_count, stride_h);
-
- cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
- jge(oh_label_end, T_NEAR);
-
- cmp(reg_oj, jcp.oh);
- jl(oh_label, T_NEAR);
- }
- L(oh_label_end);
-
- /* Compute bottom edge */
- if (b_pad > 0) {
- cmp(reg_oj, jcp.oh);
- jge(oh_bpad_label_end, T_NEAR);
-
- if (is_dilated) {
- mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
- mov(reg_tmp, 0);
- } else {
- mov(reg_kh, jcp.ihp - b_pad);
- sub(reg_kh, reg_ih_count);
- }
- L(oh_bpad_label);
- {
- compute_oh_step_disp();
- add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
- add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
- if (is_dilated) {
- inc(reg_tmp);
- cmp(reg_tmp, dilate_h);
- jl(oh_dilate_label_end, T_NEAR);
- xor_(reg_tmp, reg_tmp);
- }
- sub(reg_kh, stride_h);
- cmp(reg_kh, 0);
- jle(oh_bpad_label_end, T_NEAR);
- if (is_dilated)
- L(oh_dilate_label_end);
-
- inc(reg_oj);
- cmp(reg_oj, jcp.oh);
- jl(oh_bpad_label, T_NEAR);
- }
- L(oh_bpad_label_end);
- }
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() {
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
- int ow = jcp.ow;
- const int input_backpad_overlap
- = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
-
- const size_t filter_shift
- = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block;
- const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult;
- const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block;
-
- Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
- backpad_end_label, backpad_label;
-
- if (jcp.with_bias) bias_kernel();
-
- /* initially offset 'kd' by f_pad */
- add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
-
- mov(reg_input_d, ptr[param + GET_OFF(src)]);
- mov(reg_output_d, ptr[param + GET_OFF(dst)]);
- mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
- mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
-
- cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
- jge(loop_end_label, T_NEAR);
-
- L(d_loop_label);
-
- mov(reg_input, reg_input_d);
- mov(reg_output, reg_output_d);
-
- push(reg_input_d);
- push(reg_output_d);
- push(reg_d_index);
-
- compute_oh_loop_common();
-
- pop(reg_d_index);
- pop(reg_output_d);
- pop(reg_input_d);
-
- /* Compute 'front' edge */
- if (jcp.f_pad > 0) {
-
- /* Check if within fpad region */
- cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
- jge(fpad_end_label, T_NEAR);
-
- /* Fpad steps */
- sub(reg_kernel, filter_shift * jcp.stride_d);
- add(reg_kd_count, jcp.stride_d);
-
- /* Final number of kernel elements that overlap with input */
- const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id);
- cmp(reg_kd_count, inp_ker_overlap);
- jl(common_block_label, T_NEAR);
-
- /* Correct any excess shifts to kernel and input */
- if (jcp.f_pad <= jcp.od * jcp.stride_d) {
- /* Filter has moved beyond padding (adjust for stride effects) */
- if (jcp.f_pad % jcp.stride_d != 0) {
- int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
- add(reg_kernel, filter_shift * inp_corr);
- add(reg_input_d, input_shift * inp_corr);
- }
- } else {
- /* Filter still overlaps padding (complete reset) */
- sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
- }
-
- /* Apply correction */
- mov(reg_kd_count, jcp.kd);
- jmp(common_block_label);
-
- L(fpad_end_label);
- }
-
- /* Compute 'back' edge */
- if (jcp.back_pad > 0) {
-
- /* Check if within back_pad region */
- cmp(reg_d_index, input_backpad_overlap - 1);
- jl(backpad_end_label, T_NEAR);
- jg(backpad_label, T_NEAR);
-
- /* Execute overlap correction between the filter and the initial
- * back_pad region. */
- mov(reg_kd_count,
- jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d);
- jmp(backpad_end_label, T_NEAR);
-
- L(backpad_label);
- sub(reg_kd_count, jcp.stride_d);
- cmp(reg_kd_count, 0);
- jle(loop_end_label, T_NEAR);
-
- L(backpad_end_label);
- }
-
- /* Compute middle block */
- add(reg_input_d, input_shift * jcp.stride_d);
-
- /* Execute common block and loop */
- L(common_block_label);
- add(reg_output_d, output_shift);
- inc(reg_d_index);
- cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
- jl(d_loop_label, T_NEAR);
-
- L(loop_end_label);
-}
-
-bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() {
- // FIXME: use register mapping from the class declaration
- bool ok = jcp.ver == ver_4fma
- && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
- && everyone_is(1, jcp.stride_h, jcp.stride_w);
- if (!ok) return false;
- if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
- return false;
-
- // General code layout:
- //
- // Blocking over OH -- top level
- // (Reduces L2 pressure; not very useful right now)
- // Loop over all KHxKW kernel -- emit_kh_kw_loop()
- // Loop over OH block -- emit_h_loop()
- // Loop over OW blocks -- emit_fma_block()
- // (Supports both fully unrolled and partially unrolled versions to
- // reduce code size)
- // Loop over OW block -- emit_fma_step()
-
- int max_working_set_size = 128 * 1024;
- int pad_ow = jcp.ow;
-
- int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
- int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in;
- int row_size = inp_row_size + out_row_size;
-
- int h_block_size = jcp.oh;
- int working_set_size = row_size * h_block_size;
-
- if (working_set_size > max_working_set_size) {
- int opt_working_set_size = 48 * 1024;
- assert(opt_working_set_size < max_working_set_size);
-
- while (working_set_size > opt_working_set_size) {
- for (int i = 2; i <= h_block_size; i++)
- if (i == h_block_size)
- h_block_size = h_block_size / 2;
- else if (h_block_size % i == 0) {
- h_block_size = h_block_size / i;
- break;
- }
- working_set_size = row_size * h_block_size;
-
- if (h_block_size == 1 && working_set_size > opt_working_set_size)
- return false;
- }
- }
-
- // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below)
- if (h_block_size < nstl::max(1, jcp.t_pad)
- || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size
- : jcp.oh % h_block_size))
- return false;
-
- // check that we can use simple arithmetic for prefetch address
- // calculations
- // TODO: we need some traits for this check (Roma)
- int cache_line_size = 64;
- assert(jcp.ic_block * typesize == 64);
- assert(jcp.oc_block * typesize == 64);
-
- int num_inp_l2_pfs = jcp.tr_iw * h_block_size;
- int avg_h_loop_len = h_block_size;
- int num_inp_l2_pfs_per_fma_block
- = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
- int num_out_l2_pfs = pad_ow * h_block_size;
- int num_out_l2_pfs_per_fma_block
- = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
-
- Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors
- Reg64 reg_kh = rax;
- Reg64 reg_kw = rbx;
- Reg64 reg_tmp = abi_not_param1;
- Reg32 reg_tmp_w = reg_tmp.cvt32();
- Reg64 reg_ohs = rdx;
- Reg64 reg_ihs = rsi;
- Reg64 reg_h = r8;
- Reg64 reg_i = r9;
- Reg64 reg_j = r10;
-
- Reg64 reg_inp = r13;
- Reg64 reg_out = r14;
- Reg64 reg_ker = r15;
-
- Reg64 reg_inp_pf_l1 = rbp;
-
- Reg64 reg_inp_pf_l2 = r11;
- Reg64 reg_out_pf_l2 = r12;
-
- Xmm reg_inp_pf_save = xmm17;
- Xmm reg_out_pf_save = xmm18;
-
- Reg64 reg_inp_save = abi_param1;
- Reg64 reg_out_save = reg_tmp;
-
- auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); };
- auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
- auto inp_addr = [&](int oi, int ic1) {
- return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
- };
- auto out_addr = [&](int oi, int oj = 0) {
- assert(jcp.ver == ver_4fma);
- return ptr[reg_out
- + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in];
- };
- auto ker_addr = [&](int ic1) {
- return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out];
- };
-
- auto emit_block = [&](int h_block_size,
- bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row)
- {
- // TODO: add an fma version (Roma)
- auto pad_ow = jcp.ow;
-
- int ow4u = rnd_up(pad_ow, 4);
- int def_step_size = 16;
-
- bool has_w_tail = (pad_ow % def_step_size != 0
- || pad_ow % 4 != 0);
- bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
-
- auto emit_step = [&](int ur_ow,
- int num_inp_l1_pfs_per_fma_step,
- int num_inp_l2_pfs_per_fma_step,
- int num_out_l2_pfs_per_fma_step, bool is_w_tail)
- {
- bool block_wraparound = is_w_tail && is_last_row;
-
- assert(ur_ow % 4 == 0);
- int tail_size = ow4u % ur_ow;
- int this_ur_ow
- = (is_w_tail && tail_size) ? tail_size : ur_ow;
- int ow_last_chunk4 = pad_ow % 4;
- int ow_zero_tail4 = ow_last_chunk4
- ? 4 - ow_last_chunk4 : 0;
-
- auto emit_out_pf = [&](int oi) {
-#if 1
- if (oi + def_step_size < ur_ow || !block_wraparound)
- mic_prefetcht0(ptr[reg_out
- + ((def_step_size + oi)
- * jcp.oc_block * jcp.typesize_in)]);
- else {
- assert(block_wraparound);
- assert(oi + def_step_size >= ur_ow);
- mic_prefetcht0(ptr[reg_out_save
- + ((oi + def_step_size - ur_ow)
- * jcp.oc_block * jcp.typesize_in)]);
- }
-#else
- // XXX: This is an alternative prefetching strategy that
- // always prefetches the next row. Keeping it here for
- // future experiments (Roma)
- if (!block_wraparound)
- mic_prefetcht0(ptr[reg_out
- + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]);
- else
- mic_prefetcht0(ptr[reg_out + reg_ohs
- - ((h_block_size - 1) * jcp.ow
- - oi) * jcp.oc_block * jcp.typesize_in]);
-#endif
- if (oi < num_out_l2_pfs_per_fma_step)
- mic_prefetcht1(ptr[reg_out_pf_l2
- + oi * jcp.oc_block * jcp.typesize_in]);
- };
-
- auto emit_inp_pf = [&](int oi4, int ic1) {
- int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block;
- int num_pf_slots = jcp.ic_block * ur_ow / 4;
-
- int num_pfs = num_inp_l1_pfs_per_fma_step
- + num_inp_l2_pfs_per_fma_step;
- int pf_freq = nstl::max(1, num_pf_slots / num_pfs);
-
- if (pf_slot_idx % pf_freq)
- return;
-
- int pf_idx = pf_slot_idx / pf_freq;
-
- if (pf_idx < num_inp_l2_pfs_per_fma_step)
- mic_prefetcht1(ptr[reg_inp_pf_l2
- + pf_idx * jcp.ic_block * jcp.typesize_in]);
- else {
- pf_idx -= num_inp_l2_pfs_per_fma_step;
- // prefetch the 'tail' of the cache line because most of
- // the accesses are not aligned
- mic_prefetcht0(ptr[reg_inp_pf_l1
- + pf_idx * jcp.ic_block * jcp.typesize_in
- + cache_line_size - jcp.typesize_in]);
- }
- };
-
- auto numloads = 4;
-
- int steps = this_ur_ow;
- for (int oi4 = 0; oi4 < steps; oi4 += numloads) {
- for (int oi1 = 0; oi1 < numloads; oi1++) {
- int oi = oi4 + oi1;
- if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) {
- vmovups(zmm_out(oi), out_addr(oi));
- emit_out_pf(oi);
- } else {
- auto zmm = zmm_out(oi);
- vpxord(zmm, zmm, zmm);
- }
- }
-
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
- if (jcp.ver == ver_4fma) {
- v4fmaddps(zmm_ker(ic1),
- zmm_out(oi4), inp_addr(oi4, ic1));
- } else {
- assert(!"unknown convolution version");
- }
- emit_inp_pf(oi4, ic1);
- }
- }
- };
-
- // Input is transposed and padded but we only access about jcp.iw
- // elements so use that to compute the # of cache lines in each 'row'
- int num_inp_l1_pfs
- = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block;
-
- if (full_w_unroll) {
- emit_step(ow4u, num_inp_l1_pfs,
- num_inp_l2_pfs_per_fma_block,
- num_out_l2_pfs_per_fma_block, true);
- add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size);
- add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size);
- } else {
- Label w_loop;
- int num_w_iters = pad_ow / def_step_size;
- int num_w_iters_full = num_w_iters + has_w_tail;
- int num_inp_l1_pfs_per_fma_step
- = div_up(num_inp_l1_pfs, num_w_iters_full);
- int num_inp_l2_pfs_per_fma_step
- = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full);
- int num_out_l2_pfs_per_fma_step
- = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full);
- mov(reg_i, num_w_iters);
- L(w_loop); {
- emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
- num_inp_l2_pfs_per_fma_step,
- num_out_l2_pfs_per_fma_step, false);
- add(reg_inp, def_step_size * jcp.typesize_in);
- add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in);
- add(reg_inp_pf_l1,
- num_inp_l1_pfs_per_fma_step * cache_line_size);
- add(reg_inp_pf_l2,
- num_inp_l2_pfs_per_fma_step * cache_line_size);
- add(reg_out_pf_l2,
- num_out_l2_pfs_per_fma_step * cache_line_size);
- sub(reg_i, 1);
- jnz(w_loop);
- }
- if (has_w_tail) {
- emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
- num_inp_l2_pfs_per_fma_step,
- num_out_l2_pfs_per_fma_step, true);
- add(reg_inp_pf_l2,
- num_inp_l2_pfs_per_fma_step * cache_line_size);
- add(reg_out_pf_l2,
- num_out_l2_pfs_per_fma_step * cache_line_size);
- }
- // reset reg_inp and reg_out because emit_h_loop expects
- // unmodified pointers
- int w_offset = num_w_iters * def_step_size;
- sub(reg_inp, w_offset * jcp.typesize_in);
- sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in);
- }
- };
-
- auto emit_h_loop = [&](int h_block_size,
- bool is_last_block, bool is_last_kh_kw_iter)
- {
- Label h_loop, skip_h_loop;
- mov(reg_j, 1);
- cmp(reg_j, reg_h);
- je(skip_h_loop, T_NEAR);
- L(h_loop); {
-
- lea(reg_inp_pf_l1,
- ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]);
- emit_block(h_block_size,
- is_last_block, is_last_kh_kw_iter, false);
-
- add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
- add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in);
- add(reg_j, 1);
- cmp(reg_j, reg_h);
- jb(h_loop);
- }
-
- L(skip_h_loop);
-
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
- mic_prefetcht0(ker_addr(ic1));
-
- lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]);
- emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true);
- };
-
- auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block,
- int h_block_size)
- {
- xor_(reg_kh, reg_kh);
- Label kh_loop, kh_loop_end;
-
- int last_oh_block_size
- = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size);
- int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size;
- // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size
- int ih_block_size = oh_block_size - 1 + jcp.kh
- - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad;
-
- L(kh_loop); {
- // determine starting indices for this block
- if (is_first_block) {
- xor_(reg_tmp, reg_tmp);
- mov(reg_ohs, jcp.t_pad);
- sub(reg_ohs, reg_kh);
- cmovb(reg_ohs, reg_tmp);
-
- mov(reg_ihs, reg_ohs);
- sub(reg_ihs, jcp.t_pad);
- add(reg_ihs, reg_kh);
- } else {
- xor_(reg_ohs, reg_ohs);
- mov(reg_ihs, reg_kh);
- }
-
- // determine effective size of block based on padding
- mov(reg_tmp, oh_block_size);
- sub(reg_tmp, reg_ohs);
- mov(reg_h, ih_block_size);
- sub(reg_h, reg_ihs);
- cmp(reg_tmp, reg_h);
- cmovb(reg_h, reg_tmp);
-
- Label kh_loop_work;
- cmp(reg_h, 0);
- jg(kh_loop_work, T_NEAR);
-
- // empty h loop for this jcp.kh:
- // - set the output to 0 if necessary
- // - move ker ptr
- // - jump to the end
- sub(reg_h, 1);
- Label skip_ker_zeroing;
-
- // The reg_ker ptr has highest bit set if the output needs to be
- // zeroed. Those who have byte-aligned their data will suffer the
- // consequences :(
- // TODO: move the flag to a mask register? (Roma)
- test(reg_ker, 1);
- jz(skip_ker_zeroing, T_NEAR);
-
- Label zeroing_loop;
- vpxord(zmm0, zmm0, zmm0);
- and_(reg_ker, ~1); // temporarily clear the zeroing flag
- mov(reg_tmp, jcp.kw);
- L(zeroing_loop); {
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
- vmovups(ker_addr(ic1), zmm0);
- add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out);
- sub(reg_tmp, 1);
- jnz(zeroing_loop, T_NEAR);
- }
- // restore the zeroing flag (it will be cleared after the end of
- // emit_kh_kw_loop, but we may need it until then)
- or_(reg_ker, 1);
- jmp(kh_loop_end, T_NEAR);
-
- L(skip_ker_zeroing);
- add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw
- * jcp.typesize_out);
- jmp(kh_loop_end, T_NEAR);
-
- L(kh_loop_work);
-
- mul_by_const(reg_ihs, reg_tmp,
- jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
- mul_by_const(reg_ohs, reg_tmp,
- pad_ow * jcp.oc_block * jcp.typesize_in);
-
- add(reg_inp, reg_ihs);
- add(reg_out, reg_ohs);
-
- Label kw_loop;
- xor_(reg_kw, reg_kw);
- L(kw_loop); {
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
- auto zmm = zmm_ker(ic1);
- vpxord(zmm, zmm, zmm);
- mic_prefetcht1(ker_addr(ic1));
- }
-
- mov(reg_out_save, reg_out);
- mov(reg_inp_save, reg_inp);
- lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]);
-
-#if 0
- // XXX: Generate code with special prefetches when switching
- // blocks or at the end of the last block. Disabled to reduce
- // code size and because there's no performance benefit (Roma)
- Label regular_h_loop, end_h_loop;
- cmp(reg_kw, jcp.kw - 1);
- jne(regular_h_loop, T_NEAR);
- cmp(reg_kh, jcp.kh - 1);
- jne(regular_h_loop, T_NEAR);
-
- emit_h_loop(oh_block_size, is_last_block, true);
- jmp(end_h_loop, T_NEAR);
-
- L(regular_h_loop);
- emit_h_loop(oh_block_size, is_last_block, false);
-
- L(end_h_loop);
-#else
- emit_h_loop(oh_block_size, is_last_block, false);
-#endif
-
- mov(reg_out, reg_out_save);
- mov(reg_inp, reg_inp_save);
-
- Label do_store;
- // The reg_ker ptr has highest bit set if the output needs to
- // be zeroed. Those who have byte-aligned their data will
- // suffer the consequences :(
- mov(reg_tmp, reg_ker);
- and_(reg_ker, ~1);
- test(reg_tmp, 1);
- jnz(do_store, T_NEAR);
-
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
- auto zmm = zmm_ker(ic1);
- if (jcp.ver == ver_4fma) {
- vaddps(zmm, ker_addr(ic1));
- } else {
- assert(!"unknown convolution version");
- }
- }
-
- L(do_store);
- for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
- auto zmm = zmm_ker(ic1);
- vmovups(ker_addr(ic1), zmm);
- }
-
- mov(reg_ker, reg_tmp);
- add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
- add(reg_kw, 1);
- cmp(reg_kw, jcp.kw);
- jl(kw_loop);
- }
-
- sub(reg_inp, reg_ihs);
- sub(reg_out, reg_ohs);
-
-
- L(kh_loop_end);
- add(reg_kh, 1);
- cmp(reg_kh, jcp.kh);
- jl(kh_loop);
- }
- };
-
- mov(reg_inp, ptr[param + GET_OFF(src)]);
- mov(reg_out, ptr[param + GET_OFF(dst)]);
- mov(reg_ker, ptr[param + GET_OFF(filt)]);
- mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]);
- mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]);
- mov(reg_tmp, ptr[param + GET_OFF(channel)]);
- or_(reg_ker, reg_tmp);
-
- bool single_kh_kw_loop = (h_block_size == jcp.oh);
-
- size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in;
- size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad);
- size_t inp_block_step = inp_row_step * h_block_size;
- size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in
- * h_block_size;
-
- if (!single_kh_kw_loop) {
- // Save the original prefetch pointers from the OpenMP driver
- vmovq(reg_inp_pf_save, reg_inp_pf_l2);
- vmovq(reg_out_pf_save, reg_out_pf_l2);
- mov(reg_inp_pf_l2, reg_inp);
- add(reg_inp_pf_l2, first_inp_block_step);
- mov(reg_out_pf_l2, reg_out);
- add(reg_out_pf_l2, out_block_step);
- }
- emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size);
-
- if (!single_kh_kw_loop) {
- size_t ker_reset_offset
- = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh;
- sub(reg_ker, ker_reset_offset);
- and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
-
- add(reg_inp, first_inp_block_step);
- add(reg_out, out_block_step);
- mov(reg_inp_pf_l2, reg_inp);
- add(reg_inp_pf_l2, inp_block_step);
- mov(reg_out_pf_l2, reg_out);
- add(reg_out_pf_l2, out_block_step);
-
- int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2;
- if (num_innermost_iters > 0) {
- Label h_block_loop;
-
- mov(reg_tmp_w, num_innermost_iters);
- kmovw(reg_h_block, reg_tmp_w);
- L(h_block_loop); {
- emit_kh_kw_loop(false, false, h_block_size);
- sub(reg_ker, ker_reset_offset);
- add(reg_inp, inp_row_step * h_block_size);
- add(reg_out, out_block_step);
- mov(reg_inp_pf_l2, reg_inp);
- add(reg_inp_pf_l2, inp_block_step);
- mov(reg_out_pf_l2, reg_out);
- add(reg_out_pf_l2, out_block_step);
- kmovw(reg_tmp_w, reg_h_block);
- sub(reg_tmp_w, 1);
- kmovw(reg_h_block, reg_tmp_w);
- jnz(h_block_loop);
- }
- }
-
- // Restore the original prefetch pointers that came from the OpenMP
- // driver
- vmovq(reg_inp_pf_l2, reg_inp_pf_save);
- vmovq(reg_out_pf_l2, reg_out_pf_save);
- emit_kh_kw_loop(false, true, h_block_size);
- }
-
- return true;
-}
-
-bool jit_avx512_common_conv_bwd_weights_kernel_f32
- ::flat_4ops_compute() {
- const auto &j = jcp;
- const bool ok = j.ver == ver_4fma && j.is_1stconv
- && everyone_is(0, j.dilate_h, j.dilate_w);
- if (!ok) return false;
-
- Reg64 reg_ptr_tr_src = r8;
- Reg64 reg_ptr_dst = r9;
- Reg64 reg_ptr_wei = r10;
- Reg64 reg_ptr_bia = r11;
-
- Reg64 reg_kh_step = rax;
- Reg64 reg_oh = abi_not_param1;
- Reg64 reg_kh = rdx;
-
- Reg32 reg_flag_save = ebx;
- Reg32 reg_flag = esi;
-
- Zmm vbia(31);
-
- auto zmm_wei = [&](int kh, int kw) {
- return Zmm(8 + kh * j.kw + kw);
- };
- auto zmm_dst = [&](int ow) {
- return Zmm(ow % 8);
- };
-
- auto addr_tr_src = [&](int kh, int iw) {
- return ptr[reg_ptr_tr_src
- + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in];
- };
- auto addr_dst = [&](int ow) {
- return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in];
- };
- auto addr_wei = [&](int kh, int kw) {
- return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block
- * jcp.typesize_out];
- };
-
- auto emit_fma_block = [&](int kh_step) {
- for (int kh = 0; kh < kh_step; ++kh) {
- for (int kw = 0; kw < j.kw; ++kw) {
- auto vwei = zmm_wei(kh, kw);
- vpxord(vwei, vwei, vwei);
- }
- }
-
- for (int ow = 0; ow < j.ow; ow += 4) {
- for (int _ow = ow; _ow < ow + 4; ++_ow) {
- auto vdst = zmm_dst(_ow);
- if (_ow < j.ow)
- vmovups(vdst, addr_dst(_ow));
- else
- vpxord(vdst, vdst, vdst);
- }
-
- for (int kh = 0; kh < kh_step; ++kh) {
- for (int kw = 0; kw < j.kw; ++kw) {
- const int iw = ow + (kw % j.stride_w) * j.tr_ld
- + (kw / j.stride_w);
- v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow),
- addr_tr_src(kh, iw));
- if (1 && kh == 0 && kw < 4) {
- prefetcht1(ptr[reg_ptr_dst
- + (j.ow + ow + kw) * jcp.oc_block
- * jcp.typesize_in]);
- }
- if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */
- const int off = kw + 4 - j.kw;
- if (off >= 0 && ow + off < j.ow)
- vaddps(vbia, vbia, zmm_dst(ow + off));
- }
- }
- }
- }
-
- Label l_store;
- test(reg_flag, FLAG_MB_FIRST);
- jnz(l_store, T_NEAR);
- for (int kh = 0; kh < kh_step; ++kh) {
- for (int kw = 0; kw < j.kw; ++kw)
- vaddps(zmm_wei(kh, kw), addr_wei(kh, kw));
- }
- L(l_store);
- for (int kh = 0; kh < kh_step; ++kh) {
- for (int kw = 0; kw < j.kw; ++kw)
- vmovups(addr_wei(kh, kw), zmm_wei(kh, kw));
- }
- };
-
- auto emit_kh_loop = [&]() {
- const int kh_step_rem = j.kh % j.kh_step;
- xor_(reg_kh, reg_kh);
- mov(reg_kh_step, j.kh_step);
-
- Label l_kh_loop;
- L(l_kh_loop); {
- Label l_done;
-
- if (kh_step_rem != 0) {
- Label l_keep_kh_step;
- cmp(reg_kh, j.kh - j.kh_step);
- jle(l_keep_kh_step, T_NEAR);
-
- mov(reg_kh_step, kh_step_rem);
- emit_fma_block(kh_step_rem);
- jmp(l_done, T_NEAR);
-
- L(l_keep_kh_step);
- }
-
- emit_fma_block(j.kh_step);
-
- L(l_done);
-
- add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld
- * jcp.typesize_in);
- add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out);
- add(reg_kh, j.kh_step);
-
- cmp(reg_kh, j.kh);
- jl(l_kh_loop, T_NEAR);
- }
-
- const int kh_steps = rnd_up(j.kh, j.kh_step);
- sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in);
- sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out);
- };
-
- auto emit_oh_loop = [&]() {
- mov(reg_oh, j.oh);
-
- Label l_oh_loop;
- L(l_oh_loop); {
- Label l_restore_mb_flag, l_jump;
-
- cmp(reg_oh, j.oh);
- je(l_restore_mb_flag, T_NEAR);
-
- and_(reg_flag, ~FLAG_MB_FIRST);
- jmp(l_jump, T_NEAR);
-
- L(l_restore_mb_flag);
- mov(reg_flag, reg_flag_save);
-
- L(l_jump);
-
- emit_kh_loop();
-
- add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld
- * jcp.typesize_in);
- add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in);
-
- dec(reg_oh);
- jnz(l_oh_loop, T_NEAR);
- }
- };
-
- auto emit_bia_store = [&]() {
- if (!j.with_bias) return;
-
- Label l_bia_store, l_bia_skip;
- test(reg_flag, FLAG_IC_FIRST);
- jz(l_bia_skip);
-
- test(reg_flag, FLAG_MB_FIRST);
- jnz(l_bia_store, T_NEAR);
- vaddps(vbia, ptr[reg_ptr_bia]);
- L(l_bia_store);
- vmovups(ptr[reg_ptr_bia], vbia);
- L(l_bia_skip);
- };
-
- mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]);
- mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]);
- mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]);
- mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]);
- mov(reg_flag_save, ptr[param + GET_OFF(flags)]);
-
- vpxord(vbia, vbia, vbia);
- emit_oh_loop();
- emit_bia_store();
-
- return true;
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop()
-{
- if (flat_4ops_compute())
- return;
- if (compute_full_spat_loop())
- return;
-
- maybe_zero_kernel();
-
- if (jcp.ndims == 5) compute_d_loop_common();
- else compute_oh_loop_common();
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
-{
- preamble();
-
- mov(reg_input, ptr[param + GET_OFF(src)]);
- mov(reg_output, ptr[param + GET_OFF(dst)]);
- mov(reg_kernel, ptr[param + GET_OFF(filt)]);
-
- compute_loop();
-
- postamble();
-}
-
-status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
- jit_conv_conf_t &jcp, const convolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &diff_weights_md,
- memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) {
- if (!mayiuse(avx512_common))
- return status::unimplemented;
-
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper diff_weights_d(&diff_weights_md);
- const memory_desc_wrapper diff_bias_d(&diff_bias_md);
- const memory_desc_wrapper diff_dst_d(&diff_dst_md);
-
- const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
- int ndims = src_d.ndims();
-
- jcp = zero<decltype(jcp)>();
-
- jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
- jcp.ndims = ndims;
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
- jcp.iw = src_d.dims()[ndims-1];
- jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
- jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
- jcp.ow = diff_dst_d.dims()[ndims-1];
-
- jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
- jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
- jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
-
- jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
- jcp.l_pad = cd.padding[0][ndims-3];
-
- jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
- jcp.stride_w = cd.strides[ndims-3];
-
- jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
- jcp.dilate_w = cd.dilates[ndims-3];
-
- const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
- bool ok = true
- // general condition to simplify dilations
- && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
- && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
- && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
- // special condition to simplify dilations in compute_oh_loop_common
- && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
- if (!ok)
- return status::unimplemented;
-
- jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
- + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
- jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
- + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));
-
- /* XXX: currently, does not support dilation_d > 0 */
- if (ndims == 5)
- if (jcp.dilate_d > 0)
- return status::unimplemented;
-
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
- jcp.ohp = jcp.oh;
- jcp.owp = jcp.ow;
- jcp.aligned_threads = 0;
-
- /* check for the 1st convolution */
- jcp.is_1stconv = is_1stconv(jcp);
-
- jcp.oc_block = jcp.simd_w;
-
- bool ok_to_pad_channels = true
- && jcp.ngroups == 1
- && src_d.data_type() == data_type::f32;
-
- if (ok_to_pad_channels)
- jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
-
- if (jcp.oc % jcp.oc_block)
- return status::unimplemented;
-
- auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = with_groups
- ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
- : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
-
- if (diff_dst_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
- jcp.dst_tag = dst_tag;
- } else {
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag);
- }
- if (jcp.dst_tag != dst_tag)
- return status::unimplemented;
-
- /* conditions on bias memory */
- jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
- if (jcp.with_bias) {
- if (diff_bias_d.format_kind() == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_bias_md, x));
- }
-
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- /* kernel applicability check wrt boundaries
- * the conditions are quite general across the kernels we have,
- * but ideally the check should belong to a specific kernel... */
- const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
- const bool boundaries_ok = true
- && jcp.t_pad <= max_pad
- && jcp.b_pad <= max_pad
- && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad)
- && jcp.f_pad < jcp.kd;
- if (!boundaries_ok)
- return status::unimplemented;
-
- /* yet another common check */
- if (jcp.kw > 14)
- return status::unimplemented;
-
- /* setting register strategy */
- for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
- if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
- }
-
- if (jcp.is_1stconv) {
- auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw);
- if (src_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md, src_tag));
- jcp.src_tag = src_tag;
- } else {
- jcp.src_tag = src_d.matches_one_of_tag(src_tag);
- if (jcp.ic == 1 && jcp.src_tag != src_tag)
- jcp.src_tag = src_d.matches_one_of_tag(
- pick(ndims - 3, nwc, nhwc, ndhwc));
- }
- if (jcp.src_tag == format_tag::undef)
- return status::unimplemented;
-
- const bool src_ok = true
- && utils::everyone_is(data_type::f32,
- src_d.data_type(), diff_weights_d.data_type(),
- diff_dst_d.data_type())
- && one_of(jcp.ic, 1, 2, 3)
- && jcp.ngroups == 1;
- if (!src_ok)
- return status::unimplemented;
-
- const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad,
- jcp.stride_w), 16);
- const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1);
- const int kh_step_rem = jcp.kh % kh_step;
-
- const auto wei_4fma_tag = with_groups
- ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
- : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
-
- auto current_wei_tag = format_tag::undef;
- if (diff_weights_d.format_kind() != format_kind::any)
- current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag);
-
- const bool use_4fma = true
- && one_of(ndims, 3, 4)
- && mayiuse(avx512_mic_4ops)
- && mkldnn_thr_syncable()
- && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
- && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
- && jcp.kw <= 28 - jcp.with_bias
- && jcp.stride_w == 4
- && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
- && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
- && IMPLICATION(diff_weights_d.format_kind() != format_kind::any,
- current_wei_tag == wei_4fma_tag);
-
- if (use_4fma) {
- jcp.ver = ver_4fma;
- jcp.kh_step = kh_step;
- jcp.tr_ld = tr_ld;
- jcp.ic_block = 1;
- if (diff_weights_d.format_kind() == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag));
- jcp.wei_tag = wei_4fma_tag;
- } else {
- jcp.ver = ver_fma;
- jcp.ic_block = jcp.ic;
-
- wei_tag = with_groups
- ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
- : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
-
- if (diff_weights_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
- jcp.wei_tag = wei_tag;
- } else {
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
- }
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
- }
-
- jcp.nb_ic = jcp.ic / jcp.ic_block;
- } else {
- auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
- if (src_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md, src_tag));
- jcp.src_tag = src_tag;
- } else {
- jcp.src_tag = src_d.matches_one_of_tag(src_tag);
- }
- if (jcp.src_tag != src_tag)
- return status::unimplemented;
-
- if (diff_weights_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
- jcp.wei_tag = wei_tag;
- } else {
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
- }
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
-
- jcp.ic_block = jcp.simd_w;
- if (ok_to_pad_channels)
- jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
- jcp.nb_ic = jcp.ic / jcp.ic_block;
- if ((mayiuse(avx512_mic) || mayiuse(avx512_core))
- && utils::everyone_is(data_type::f32,
- src_d.data_type(), diff_weights_d.data_type(),
- diff_dst_d.data_type())) {
- jcp.ver = ver_fma;
- if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 &&
- everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) &&
- mkldnn_thr_syncable()) {
- jcp.ver = ver_4fma;
- }
- } else {
- return status::unimplemented;
- }
- if (jcp.ver == ver_4fma) {
- jcp.ur_w = jcp.ow;
- // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to
- // cross the right boundary. The only requirement is not to have
- // NaNs there because another multiplicand is always guaranteed to
- // be zero. This also may require the top-level driver to allocate
- // four extra guarding elements at the very end of the buffer.
- // I'm not proud of this hack, but it improves performance by
- // about 5-10% depending on the dimensions (Roma)
-
- const int tr_round = 4;
-
- jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round);
- jcp.tr_src_num_guard_elems = tr_round; // upper bound
- }
- }
-
- if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) {
- jcp.typesize_in = sizeof(float);
- jcp.typesize_out = sizeof(float);
- } else
- return status::unimplemented;
-
- bool args_ok = true
- && jcp.ic % jcp.ic_block == 0
- && jcp.oc % jcp.oc_block == 0
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= diff_dst_d.padded_dims()[1]
- && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
- if (!args_ok) return status::unimplemented;
-
- { // balancing
- int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
- balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
- jcp.nthr = nthr;
- jcp.nthr_mb = nthr_mb;
- jcp.nthr_g = nthr_g;
- jcp.nthr_oc_b = nthr_oc_b;
- jcp.nthr_ic_b = nthr_ic_b;
- }
-
- return status::success;
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- if (jcp.ver == ver_4fma) {
- if (jcp.is_1stconv) {
- const size_t tr_src_size =
- jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
- scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
- } else {
- // XXX: See the comment about tr_iw and guarding elements in
- // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
- const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
- const size_t min_tr_src_size_per_thr
- = jcp.ih * jcp.ic_block * jcp.tr_iw;
- const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
- + jcp.tr_src_num_guard_elems;
- scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
- }
-
- /* prepare synchronization contexts */
- if (jcp.nthr_oc_b > 1) {
- const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
- scratchpad.book(key_conv_tr_src_bctx,
- sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
- }
- }
-
- if (jcp.nthr_mb > 1) {
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
- * jcp.kh * jcp.kw * jcp.kd;
- const int bia_size = jcp.ngroups * jcp.oc;
- const size_t wei_bia_reduction_size = wei_size + bia_size;
-
- scratchpad.book(key_conv_wei_bia_reduction,
- jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
- scratchpad.book(key_conv_wei_bia_reduction_bctx,
- sizeof(simple_barrier::ctx_t));
- }
-
- if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
- scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
-}
-
-void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
- const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
- int &nthr_oc_b_, int &nthr_ic_b_)
-{
- nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
-
- const int max_threads = mkldnn_get_max_threads();
-
- if (max_threads < j.ngroups) {
- /* simplification... fortunately it doesn't hurt much */
- return;
- }
-
- if (!mkldnn_thr_syncable() && j.ver == ver_4fma) {
- // should not happen -- the driver is not ready
- // for TBB-like non-synchronous threading yet
- return;
- }
-
- if (j.ver == ver_4fma && j.is_1stconv) {
- nthr_g_ = 1;
- nthr_oc_b_ = 1;
- nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
- nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
- nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
- return;
- }
-
- nthr_g_ = j.ngroups;
- const int nthr = max_threads / nthr_g_;
-
- auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
- /* calculate per thread memory cost (read/write). high level optimizer
- * tries to minimize memory consumption. few notes:
- * (n1) unclear why, but that essentially helps first convolution...
- * (n2) assuming the reduction over minibatch is always there:
- * - instead of 8 it should be 5 here (write ~= 2 read):
- * kernel: temporal workspace 1 write
- * reduction: 1 read from workspace and 1 write to the diff_wei
- * - but experiments showed 8 works better than 5 or 6... */
-
- const int src_coef = j.ver == ver_4fma ? 4 : 1;
- const int dst_coef = 1;
- const int wei_coef = 8;
-
- return 0
- + src_coef
- * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
- / j.stride_d / j.stride_h / j.stride_w /* (n1) */
- + dst_coef
- * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
- + wei_coef /* (n2) */
- * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
- * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
- };
-
- int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
-
- /* step 1: find the best thread distribution with lowest memory cost */
- const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
- for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
- const int nthr_par = nthr / nthr_mb;
- const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
- for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
- int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
-
- int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
- if (mem_cost <= best_mem_cost) {
- best_mem_cost = mem_cost;
- nthr_mb_ = nthr_mb;
- nthr_oc_b_ = nthr_oc_b;
- nthr_ic_b_ = nthr_ic_b;
- }
- }
-
- if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
- }
-
- if (!mayiuse(avx512_mic)) {
- auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
- return 1
- * div_up(j.mb, nthr_mb)
- * div_up(j.ngroups, nthr_g_)
- * div_up(j.nb_oc, nthr_oc_b)
- * div_up(j.nb_ic, nthr_ic_b);
- };
-
- /* step 2: search for a thread distribution with lower compute cost.
- * the constraints:
- * - memory cost cannot exceed 110% of the best found in step 1
- * - unless compute cost drops to at most 75% of the current best
- * note: both constants were found empirically */
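- /* Illustrative numbers (hypothetical): with best_mem_cost = 1000 and
- * best_comp_cost = 100, a candidate with mem_cost = 1080 and comp_cost = 95
- * is accepted via opt1 below (95 <= 100 and 1080 < 1.1 * 1000), while one
- * with mem_cost = 1200 and comp_cost = 70 is accepted via opt2
- * (4 * 70 = 280 <= 3 * 100 = 300). */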
- int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
- for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
- const int nthr_par = nthr / nthr_mb;
- const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
- for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
- int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
- int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
- int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
-
- const bool opt1 = comp_cost <= best_comp_cost
- && mem_cost < 1.1 * best_mem_cost;
- const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
-
- if (opt1 || opt2) {
- best_comp_cost = comp_cost;
- nthr_mb_ = nthr_mb;
- nthr_oc_b_ = nthr_oc_b;
- nthr_ic_b_ = nthr_ic_b;
- }
- }
-
- if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
- }
- }
-
- if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
- nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
- nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
-
- assert(nthr_ <= max_threads);
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
-}
-
-template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
-template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp
deleted file mode 100644
index f76770797a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp
+++ /dev/null
@@ -1,423 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
-#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template<typename Vmm>
-struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
-
- _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise);
-
- generate();
- jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
- }
-
- ~_jit_avx512_common_conv_fwd_kernel() {
- delete eltwise_injector_;
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
-
- jit_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker_)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- enum {
- typesize = sizeof(float),
- ker_reg_base_idx = 28,
- };
-
- reg64_t param = abi_param1;
- reg64_t reg_inp = r8;
- reg64_t reg_ker = r9;
- reg64_t reg_out = r10;
-
- reg64_t reg_inp_prf = r11;
- reg64_t reg_ker_prf = r12;
- reg64_t reg_out_prf = r13;
- reg64_t reg_owb = r12;
-
- reg64_t aux_reg_inp = r14;
- reg64_t aux_reg_ker = r15;
-
- reg64_t aux_reg_inp_prf = rsi;
- reg64_t aux_reg_ker_prf = rdx;
-
- reg64_t reg_channel = rsi;
- reg64_t reg_bias = rdx;
-
- reg64_t aux_reg_ker_d = r9;
- reg64_t aux_reg_inp_d = rbx;
- reg64_t aux_reg_inp_d_prf = r13;
- reg64_t aux_reg_ker_d_prf = abi_not_param1;
- reg64_t reg_ki = r10;
-
- reg64_t reg_kj = rax;
- reg64_t reg_relu_ns = rax;
- reg64_t reg_oi = rbx;
- reg64_t reg_kh = abi_not_param1;
-
- reg64_t reg_tmp = rbp;
-
- reg64_t reg_ic_loop = rdx;
- reg64_t reg_inp_loop = rsi;
-
- reg64_t reg_init_flag = r13;
- reg64_t reg_bias_ptr = param;
-
- reg64_t aux_reg_ic = r12;
- reg64_t reg_binp = rax;
- reg64_t reg_bout = r11;
- reg64_t aux1_reg_inp = rbx;
- reg64_t aux_reg_out = abi_not_param1;
-
- reg64_t reg_long_offt = r11;
- reg64_t reg_out_long_offt = r14;
-
- inline Vmm vmm_ker(int i_ic) {
- assert(i_ic < 4);
- return Vmm(ker_reg_base_idx + i_ic);
- }
-
- inline Vmm vmm_out(int i_ur, int i_oc) {
- int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < ker_reg_base_idx);
- return Vmm(idx);
- }
-
- inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
- int idx = i_ic + nb_x_blocking * jcp.ur_w;
- assert(idx < 31);
- return Vmm(idx);
- }
-
- Xbyak::Reg64 imm_addr64 = r15;
- Vmm vmm_wei = Vmm(31);
-
- jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
-
- inline void prepare_output(int ur_w);
- inline void store_output(int ur_w);
- inline void compute_loop_fma(int ur_w, int pad_l, int pad_r);
- inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
- inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r);
- inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r);
- inline void compute_loop(int ur_w, int pad_l, int pad_r);
-
- void generate();
-
- inline size_t get_output_offset(int oi, int n_oc_block) {
- return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh
- * jcp.ow * jcp.od + oi) * jcp.oc_block;
- }
-
- inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
- size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1;
- size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id;
- return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1)
- + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str);
- }
-
- inline int get_kernel_offset(int ki,int ic,int n_oc_block,int ker_number) {
- return jcp.typesize_in * jcp.oc_block
- * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd
- + (ic + ker_number) + ki * jcp.ic_block);
- }
-
- inline int get_ow_start(int ki, int pad_l) {
- return nstl::max(0,
- utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
- }
-
- inline int get_ow_end(int ur_w, int ki, int pad_r) {
- return ur_w - nstl::max(0, utils::div_up(pad_r
- - (jcp.kw - 1 - ki)
- * (jcp.dilate_w + 1),
- jcp.stride_w));
- }
-};
-
-struct jit_avx512_common_conv_fwd_kernel {
-
- jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr) :
- jit_ker(nullptr),
- zmm_kernel_(nullptr),
- xmm_kernel_(nullptr) {
- int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block;
- switch (ch_block) {
- case 16:
- zmm_kernel_ =
- new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
- ajcp, attr);
- jit_ker = zmm_kernel_->jit_ker_;
- return;
- case 4:
- xmm_kernel_ =
- new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
- ajcp, attr);
- jit_ker = xmm_kernel_->jit_ker_;
- return;
- default:
- assert(!"invalid channel blocking");
- }
- }
-
- ~jit_avx512_common_conv_fwd_kernel() {
- delete xmm_kernel_;
- delete zmm_kernel_;
- }
-
- enum {
- typesize = sizeof(float)
- };
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- memory_desc_t &src_pd,
- memory_desc_t &weights_pd,
- memory_desc_t &dst_pd,
- memory_desc_t &bias_pd,
- const primitive_attr_t &attr,
- int nthreads);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- void(*jit_ker)(jit_conv_call_s *);
- _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
- _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
-};
-
-struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
-
- jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
- {
- generate();
- jit_ker = (void (*)(jit_conv_call_s *))getCode();
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32)
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- enum {
- typesize = sizeof(float),
- ker_reg_base_idx = 28,
- };
-
- reg64_t param = abi_param1;
- reg64_t reg_dst = r8;
- reg64_t reg_ker = r9;
- reg64_t reg_src = r10;
-
- reg64_t reg_dst_prf = r11;
- reg64_t reg_ker_prf = r12;
- reg64_t reg_src_prf = r13;
-
- reg64_t aux_reg_dst = r14;
- reg64_t aux_reg_ker = r15;
-
- reg64_t aux_reg_dst_prf = rsi;
- reg64_t aux_reg_ker_prf = rdx;
-
- reg64_t aux_reg_dst_d_prf = r13;
- reg64_t aux_reg_dst_d = rbx;
- reg64_t aux_reg_ker_d_prf = abi_not_param1;
- reg64_t aux_reg_ker_d = r9;
- reg64_t reg_ki = r10;
-
- reg64_t reg_kj = rax;
- reg64_t reg_oi = rbx;
- reg64_t reg_kh = abi_not_param1;
-
- reg64_t reg_channel = rsi;
-
- reg64_t reg_tmp = rbp;
- reg64_t reg_long_offt = r14;
-
- inline Xbyak::Zmm zmm_ker(int i_ic) {
- assert(i_ic < 4);
- return Xbyak::Zmm(ker_reg_base_idx + i_ic);
- }
- inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
- int idx = i_ic + nb_x_blocking * jcp.ur_w;
- assert(idx < 31);
- return Xbyak::Zmm(idx);
- }
- inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
- int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < ker_reg_base_idx);
- return Xbyak::Zmm(idx);
- }
-
- Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
-
- inline void prepare_output(int ur_w);
- inline void store_output(int ur_w);
- inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow);
- inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow);
- inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow);
- inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
- void generate();
-
- inline int get_iw_start(int ki, int l_overflow)
- {
- int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
- + l_overflow * jcp.stride_w
- - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
-
- return res;
- }
-
- inline int get_iw_end(int ur_w, int ki, int r_overflow)
- {
- if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
- ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
- int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
- + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
-
- return ur_w - res;
- }
-};
-
-struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator {
-
- jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp)
- : jcp(ajcp)
- {
- generate();
- jit_ker = (void (*)(jit_conv_call_s *))getCode();
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32)
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- memory_desc_t &src_md,
- memory_desc_t &diff_weights_md,
- memory_desc_t &diff_bias_md,
- memory_desc_t &diff_dst_md);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- enum {typesize = sizeof(float)};
- static const int max_ur_w;
-
- reg64_t param = abi_param1;
- reg64_t reg_input = rax;
- reg64_t reg_kernel = rdx;
- reg64_t reg_output = rsi;
- reg64_t b_ic = abi_not_param1;
- reg64_t kj = r8;
- reg64_t reg_kh = r9;
- reg64_t reg_ur_w_trips = r10;
- reg64_t reg_oj = r15;
- reg64_t reg_ih_count = rbx;
- reg64_t reg_tmp = r14;
- reg64_t reg_long_offt = r14;
-
- reg64_t ki = r11;
- reg64_t reg_kd_count = r12;
- reg64_t reg_oi = r12;
- reg64_t reg_d_index = r13;
- reg64_t reg_input_d = r15;
- reg64_t reg_output_d = rbx;
- reg64_t aux_reg_input = r12;
- reg64_t aux_reg_kernel = r13;
- reg64_t reg_bias = rbx;
-
- inline void bias_kernel();
- inline void maybe_zero_kernel();
- inline void compute_oh_step_unroll_ow_icblock(int ic_block_step,
- int max_ur_w);
- inline void od_step_comeback_pointers();
- inline void oh_step_comeback_pointers();
- inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
- inline void compute_ic_block_step(int ur_w,
- int pad_l, int pad_r, int ic_block_step,
- int input_offset, int kernel_offset, int output_offset,
- bool input_wraparound = false);
- inline void compute_ic_block_step_fma(int ur_w,
- int pad_l, int pad_r, int ic_block_step,
- int input_offset, int kernel_offset, int output_offset,
- bool input_wraparound);
- inline void compute_ic_block_step_4fma(int ur_w,
- int pad_l, int pad_r, int ic_block_step,
- int input_offset, int kernel_offset, int output_offset,
- bool input_wraparound);
- inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
- inline void compute_oh_step_disp();
- inline void compute_oh_loop_common();
- inline void compute_d_loop_common();
-
- inline bool compute_full_spat_loop();
- inline bool flat_4ops_compute();
-
- inline void compute_loop();
-
- void generate();
-
- static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
- int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp
deleted file mode 100644
index 1bdcd0d6a8..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp
+++ /dev/null
@@ -1,1163 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "cpu_memory.hpp"
-
-#include <math.h>
-
-#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
-
-#ifndef KERNEL_SIZE_THRESHOLD
-#define KERNEL_SIZE_THRESHOLD 16
-#endif
-
-#define MIN_REQUIRED_DIMN_REG_BLOCK 14
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace {
-
-using namespace mkldnn::impl::utils;
-
-unsigned int L1_cache_size = get_cache_size(1, true);
-unsigned int L2_cache_size = get_cache_size(2, true);
-unsigned int LLC_data_size = get_cache_size(3, false);
-
-// the test function takes jcp, the candidate and the current best.
-// it returns true if the new candidate is better
-int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
- int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
-{
- int best_divisor = default_best;
- auto test_num
- = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
- if (test(jcp, num, best_divisor)) {
- best_divisor = num;
- }
- };
-
- for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
- if (number % divisor == 0) {
- test_num(jcp, divisor);
- test_num(jcp, number / divisor);
- }
- }
-
- return best_divisor;
-}
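-
-// Illustrative trace (hypothetical input): for number = 56 the loop above
-// visits the divisor pairs (1, 56), (2, 28), (4, 14) and (7, 8), so every
-// divisor of 56 is offered to the test predicate exactly once, and the
-// predicate decides which of them ends up as best_divisor.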
-
-namespace {
-bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
- if (jcp.ver == ver_4fma)
- return jcp.mb >= 32;
- else
- return jcp.mb >= 16;
-}
-}
-
-/* assumes 512-bit registers */
-/* TODO: add support for strides */
-/* TODO: handle the prefetch distance automatically */
-typedef enum cache_t_ { L1, L2, L3 } cache_t;
-
-template <typename data_t>
-struct prefetcher_t {
- prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
- cache_t cache_type, size_t block_size, /* in number of elements*/
- int nb_instructions_in_block, int fma_ipc)
- : cg_(generator)
- , reg_base_addr_(reg_base_addr)
- , cache_type_(cache_type)
- , cache_block_size_(block_size)
- {
- nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
- prefetch_spread_
- = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
- prefetch_blk_
- = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
-
- /* assumption: when fetching from Li, data is already in L(i+1) */
- int cache_latency;
- switch (cache_type_) {
- case L1: cache_latency = 14; break;
- case L2:
- case L3:
- default: cache_latency = 250; break;
- }
-
- prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
- }
-
- void prefetch(int instruction_number)
- {
- if (instruction_number % prefetch_spread_ == 0) {
- for (int i = 0; (i < prefetch_blk_)
- && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
- i++, prefetches_issued_++) {
- prefetch_inst_(cg_->EVEX_compress_addr(
- reg_base_addr_, (cache_block_size_ * prefetch_distance_)
- * sizeof(data_t)
- + (prefetches_issued_ * 64)));
- }
- }
- }
-
-private:
- void prefetch_inst_(const Xbyak::Address &addr)
- {
- switch (cache_type_) {
- case L1: cg_->prefetcht0(addr); break;
- case L2: cg_->prefetcht1(addr); break;
- case L3: cg_->prefetcht2(addr); break;
- default:
- break; // TODO: raise an exception or put an assert
- }
- }
-
- jit_generator *cg_;
- Xbyak::Reg64 reg_base_addr_;
- cache_t cache_type_;
- int cache_block_size_ = 0;
- int nb_cache_lines_to_prefetch_ = 0;
- int prefetches_issued_ = 0;
- int prefetch_spread_ = 0;
- int prefetch_blk_ = 0;
- int prefetch_distance_ = 0;
-};
-
-// utilities to support kernel parameter selection
-bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
- int dimM_block, int dimM_simd_block, float C)
-{
- float lhs = (dimM_block * dimN_reg_block * dimM_simd_block
- + dimM_block * dimK_block * dimK_reg_block
- * dimM_simd_block
- + dimK_block * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs < rhs);
-}
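-
-// Illustrative sizing for check_cond1 (hypothetical numbers, assuming a
-// 32 KB L1 cache): with dimM_block = 1, dimN_reg_block = 14, dimK_block = 4,
-// dimK_reg_block = 16 and dimM_simd_block = 16, the working set is
-// (1*14*16 + 1*4*16*16 + 4*14*16) * 4 = 8576 bytes, well under
-// 0.75 * 32768 = 24576 bytes, so the condition holds.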
-
-bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
- int dimM_block, int dimM_simd_block, float C)
-{
- float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block
- + dimK_block * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs < rhs);
-}
-
-bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
- int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block,
- float C)
-{
- float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
- + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
- * dimM_simd_block
- + nb_dimN_reg_block * dimK_nb_block * dimK_block
- * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L2_cache_size;
- return (lhs < rhs);
-}
-}
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
- bool is_beta_zero)
-{
- // const int dimK_simd_block = jcp.dimK_reg_block;
-
- // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
- // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
- // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
- // dimK_reg_block++)
- // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
- // C[dimM_block][tile] +=
- // A[dimM_block][dimK_block][dimK_reg_block] *
- // broadcast(B[dimK_block][tile][dimK_reg_block]);
- // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
- // so we load it before the loop on tile
- // 2) the loop on tile must be fully unrolled. Don't know about the one on
- // dimK_reg_block. I think it should be unrolled as well.
-
- auto inner_loops = [=]() {
- Label dimM_block_loop, dimK_block_loop;
- const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
- const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
-
- prefetcher_t<float> L1_pf(this, reg_srcB, L1,
- jcp.dimN_reg_block * jcp.dimK_reg_block,
- jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
- fma_ipc);
- prefetcher_t<float> L2_pf(this, reg_srcB, L2,
- jcp.dimN_reg_block * jcp.dimK_reg_block,
- jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
- fma_ipc);
-
- if (jcp.dimM_block > 1) {
- mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
- L(dimM_block_loop);
- }
- {
- // First, we zero the accumulators on the first nb_ic iteration,
- // otherwise we load them
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm(jcp.zmm_start + tile);
- if (is_beta_zero)
- vpxord(zmm, zmm, zmm);
- else
- vmovups(zmm, zword[reg_dstC + 64 * tile]);
- }
-
- if (jcp.dimK_block > 1) {
- mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
- L(dimK_block_loop);
- }
- {
- auto load_A = [=](int reg_idx, int offset) {
- for (int i = 0; i < inc_dimK_reg_block; i++)
- vmovups(Zmm(reg_idx + i),
- zword[reg_srcA + 64 * (offset + i)]);
- };
-
- // Used when doing double buffering
- int next = 0;
- if (jcp.double_buffering) {
- load_A(next, 0);
- }
- for (int dimK_reg_block = 0;
- dimK_reg_block < jcp.dimK_reg_block;
- dimK_reg_block += inc_dimK_reg_block) {
- int current;
- /* Loading the next vector from A */
- current = next;
- if (jcp.double_buffering) {
- next = (dimK_reg_block + inc_dimK_reg_block)
- % (2 * inc_dimK_reg_block);
- load_A(next, dimK_reg_block + inc_dimK_reg_block);
- } else {
- next = 0;
- load_A(next, dimK_reg_block);
- }
- /* Performing the fmas */
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm(jcp.zmm_start + tile);
- if (jcp.ver != ver_avx512_core)
- L1_pf.prefetch(
- dimK_reg_block * jcp.dimN_reg_block + tile);
- if (jcp.ver == ver_4fma)
- v4fmaddps(zmm, Zmm(current),
- EVEX_compress_addr(reg_srcB,
- 64 * tile + dimK_reg_block * 4));
- else
- vfmadd231ps(zmm, Zmm(current),
- EVEX_compress_addr(reg_srcB,
- 64 * tile + dimK_reg_block * 4,
- true));
- if (jcp.ver != ver_avx512_core)
- L2_pf.prefetch(
- dimK_reg_block * jcp.dimN_reg_block + tile);
- }
- }
-
- add(reg_srcA, jcp.dimK_reg_block * 64);
- add(reg_srcB, jcp.dimN_reg_block * 64);
- if (jcp.dimK_block > 1) {
- sub(reg_dimK_block_loop_cnt, 1);
- jnz(dimK_block_loop);
- }
- }
-
-
- auto store_output = [=](bool output_is_aligned) {
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm(jcp.zmm_start + tile);
- if (output_is_aligned
- && jcp.dimK_nb_block == 1
- && (jcp.dimN * jcp.dimM * alpha * alpha
- * sizeof(float) > 2 * LLC_data_size))
- vmovntps(zword[reg_dstC + 64 * tile], zmm);
- else
- vmovups(zword[reg_dstC + 64 * tile], zmm);
- }
- };
-
- Label unaligned_store, end_store;
- test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
- jnz(unaligned_store, T_NEAR);
- store_output(true);
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- store_output(false);
- }
- L(end_store);
-
- if (jcp.dimM_block > 1) {
- sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
- add(reg_dstC, jcp.dimN_reg_block * 64);
- sub(reg_dimM_block_loop_cnt, 1);
- jnz(dimM_block_loop);
- }
- }
- };
-
- /* Preamble */
- preamble();
-
- /* kernel */
- inner_loops();
-
- /* Postamble */
- postamble();
- ret();
-}
-
-status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d)
-{
-
- if (mayiuse(avx512_core))
- return status::unimplemented;
- else if (!mayiuse(avx512_common))
- return status::unimplemented;
- else if (mayiuse(avx512_mic_4ops))
- jcp.ver = ver_4fma;
- else
- jcp.ver = ver_fma;
-
- jcp.nthr = mkldnn_get_max_threads();
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
- jcp.kh = weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
- jcp.r_pad = nstl::max(
- 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
- jcp.b_pad = nstl::max(
- 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
- jcp.ohp = jcp.oh;
- jcp.owp = jcp.ow;
-
- bool ok_to_pad_channels = jcp.ngroups == 1;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
-
- // Checking conditions not supported by these kernels
- if (jcp.ngroups != 1)
- return status::unimplemented;
- if ((jcp.kh != 3) || (jcp.kw != 3))
- return status::unimplemented;
- if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
- return status::unimplemented;
- if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
- return status::unimplemented;
- if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
- return status::unimplemented;
-
- format_tag_t dat_tag = nChw16c;
- format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- if (jcp.src_tag != dat_tag) return status::unimplemented;
- if (jcp.wei_tag != wei_tag) return status::unimplemented;
- if (jcp.dst_tag != dat_tag) return status::unimplemented;
-
- bool layout_consistency = true
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
- if (!layout_consistency) return status::unimplemented;
-
- return status::success;
-}
-
-
-status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
-
- auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_reg_block, int current_best) {
- return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
- && (dimN_reg_block < jcp.nb_reg)
- && (dimN_reg_block < current_best);
- };
- jcp.dimN_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
-
- if (jcp.dimN_reg_block >= jcp.nb_reg) {
- auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_reg_block, int current_best) {
- return (dimN_reg_block < jcp.nb_reg)
- && (dimN_reg_block > current_best);
- };
-
- jcp.dimN_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
- }
-
- //********************* Choosing dimK_block **********************//
- auto test_cond1_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
- 1, jcp.dimM_simd_block, .75f)
- && (dimK_block > current_best);
- };
-
- auto test_cond1_bis_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
- jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f)
- && (dimK_block > current_best);
- };
-
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
- // If we are not able to use streams, we fall back to condition [1]
- if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
- jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
-
- //********************* Choosing dimM_block **********************//
- jcp.dimM_simd_block = 16;
- /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/
- auto test_cond1_dimM_block = [](
- jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
- return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f)
- && (dimM_block > current_best);
- };
-
- auto test_cond1_bis_dimM_block = [](
- jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
- return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f)
- && (dimM_block > current_best);
- };
-
- if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
- jcp.dimM_block = get_divisor_satisfying_cond(
- jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
- else
- jcp.dimM_block = get_divisor_satisfying_cond(jcp,
- jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block);
- jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
-
- //******************* Choosing dimN_block *******************//
- auto test_cond2_dimN_block = [](
- jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
- return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
- jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
- jcp.dimM_simd_block, .5f)
- && (dimN_block > current_best);
- };
-
- jcp.dimN_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
- jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
- jcp.sched_policy = WSCHED_DATA_W_S_G_D;
- return status::success;
-}
-
-status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
- jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
-{
- jcp.dimK_reg_block = 16;
- jcp.dimM_simd_block = 16;
-
- // TODO: replace double buffering with n-tuple buffering to maximize register
- // usage.
- // the choice of the number of buffers will then come after choosing
- // dimN_reg_block
- jcp.double_buffering = true;
- if (jcp.double_buffering)
- jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2);
- else
- jcp.zmm_start = 1;
- jcp.nb_reg = 32 - jcp.zmm_start;
-
- jcp.dimN = dimN;
- jcp.dimK = dimK;
- jcp.dimM = dimM;
-
- jcp.sched_policy = WSCHED_INVALID;
- set_wsched_DATA_W_S_G_D_avx512_common(jcp);
-
- assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
- return status::success;
-}
-
-bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_relu(0) || is_sum(0); // relu or sum
- case 2: return (is_sum(0) && is_relu(1)) ||
- (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
- case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
- status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
-
- if (st != status::success)
- return st;
-
- // Winograd specific initialization
- jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
- jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
-
- status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
- jcp.ic_simd_block = jcp.dimK_reg_block;
- jcp.ic_block = jcp.dimK_block;
- jcp.nb_ic = jcp.dimK_nb_block;
- jcp.oc_simd_block = jcp.dimM_simd_block;
- jcp.oc_block = jcp.dimM_block;
- jcp.nb_oc = jcp.dimM_nb_block;
- jcp.tile_block_ur = jcp.dimN_reg_block;
- jcp.nb_tile_block_ur = jcp.dimN_block;
- jcp.tile_block = jcp.dimN_nb_block;
- jcp.tile_4fma_padding = 0; // only relevant for backward weights
-
- return res;
-}
-
-status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d)
-{
- status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
-
- if (st != status::success)
- return st;
-
- jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
- jcp.oc_simd_block = jcp.dimK_reg_block;
- jcp.oc_block = jcp.dimK_block;
- jcp.nb_oc = jcp.dimK_nb_block;
- jcp.ic_simd_block = jcp.dimM_simd_block;
- jcp.ic_block = jcp.dimM_block;
- jcp.nb_ic = jcp.dimM_nb_block;
- jcp.tile_block_ur = jcp.dimN_reg_block;
- jcp.nb_tile_block_ur = jcp.dimN_block;
- jcp.tile_block = jcp.dimN_nb_block;
- jcp.tile_4fma_padding = 0; // only relevant for backward weights
-
- return res;
-}
-
-void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
-{
- auto load_B = [=](int reg_idx, int offset) {
- for (int i = 0; i < 4; i++) {
- vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
- }
- };
-
- preamble();
- int curr = 0;
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
- size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
- jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
- jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
- mov(reg_transB_idx, transB_offset);
- for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
- /*double buffering to hide load latencies*/
- int next = (curr + 4) % 8;
- if (i == 0 && tb == 0) {
- load_B(0, origB_offset);
- }
- if (tb + 4 < (jcp.dimK_4fma -1)) {
- load_B(next, origB_offset + 4);
- } else if (i < alpha - 1) {
- load_B(next, origB_offset + jcp.dimK_4fma);
- }
-
- vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
- vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
- vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
- vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));
-
- vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
- vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));
-
- vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
- vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));
-
- vmovntps(zword[reg_transB + reg_transB_idx
- + sizeof(float) * tb * jcp.dimN_reg_block],
- Zmm(curr+2));
- vmovntps(zword[reg_transB + reg_transB_idx
- + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
- Zmm(curr+3));
- vmovntps(zword[reg_transB + reg_transB_idx
- + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
- Zmm(8));
- vmovntps(zword[reg_transB + reg_transB_idx
- + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
- Zmm(9));
- curr = next;
-
- }
- }
- }
- postamble();
- ret();
-}
-void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
- bool is_first_tile)
-{
- // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
- // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
- // for (int nb_tile_block_ur = 0; nb_tile_block_ur <
- // jcp.nb_tile_block_ur; nb_tile_block_ur++)
- // for (int tile_block_ur = 0; tile_block_ur <
- // jcp.tile_block_ur; tile_block_ur++)
- // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
- // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
- // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
- // *
- // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
- auto inner_loops = [=]() {
- int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
- const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
- prefetcher_t<float> L1_pf(this, reg_srcB, L1,
- jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
- jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
- / inc_fma,
- fma_ipc);
- prefetcher_t<float> L2_pf(this, reg_srcB, L2,
- jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
- jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
- / inc_fma,
- fma_ipc);
-
- auto load_A = [=](int reg_idx, int offset) {
- for (int i = 0; i < inc_fma; i++) {
- vmovups(Zmm(reg_idx + i),
- zword[reg_srcA +
- sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
- }
- };
-
- Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
- if (jcp.dimM_block > 1) {
- mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
- L(dimM_block_loop);
- }
- { /************* OC_block (M) loop ***********/
- if (jcp.dimN_block > 1) {
- mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
- L(dimN_block_loop);
- }
- { /*************** IC_block (N) loop *********/
- for (int dimN_reg_block = 0;
- dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
- Zmm zmm(jcp.zmm_start + dimN_reg_block);
- if (is_first_tile)
- vpxord(zmm, zmm, zmm);
- else
- vmovups(zmm, zword[reg_dstC +
- dimN_reg_block * jcp.dimM_simd_block *
- sizeof(float)]);
- }
-
- if (jcp.dimK_block > 1) {
- mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
- L(dimK_block_loop);
- }
- { /************* nb_tile_ur(K) loop ********/
- int next = 0;
- if (jcp.double_buffering) {
- load_A(next, 0);
- }
- for (int dimK_reg_block = 0;
- dimK_reg_block < jcp.dimK_reg_block;
- dimK_reg_block++) {
- int srcB_offset = dimK_reg_block * jcp.dimK_4fma
- * jcp.dimN_reg_block;
- for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
- dimK_4fma += inc_fma) {
- int current = next;
- if (jcp.double_buffering) {
- next = (dimK_reg_block * jcp.dimK_4fma
- + dimK_4fma + inc_fma)
- % (2 * inc_fma);
- load_A(next, dimK_reg_block * jcp.dimK_4fma
- + dimK_4fma + inc_fma);
- } else {
- next = 0;
- load_A(next, dimK_reg_block * jcp.dimK_4fma
- + dimK_4fma);
- }
- for (int dimN_reg_block = 0;
- dimN_reg_block < jcp.dimN_reg_block;
- ++dimN_reg_block) {
- L1_pf.prefetch(srcB_offset / inc_fma
- + dimK_4fma / inc_fma
- * jcp.dimN_reg_block
- + dimN_reg_block);
- L2_pf.prefetch(srcB_offset / inc_fma
- + dimK_4fma / inc_fma
- * jcp.dimN_reg_block
- + dimN_reg_block);
- if (jcp.ver == ver_4fma) {
- int srcB_trans_offset = (dimK_4fma / 4) * 64
- + dimK_4fma % 4;
- v4fmaddps(
- Zmm(jcp.zmm_start + dimN_reg_block),
- Zmm(current),
- EVEX_compress_addr(reg_srcB,
- sizeof(float) * (
- srcB_offset +
- srcB_trans_offset +
- (dimN_reg_block % 4) * 16 +
- (dimN_reg_block / 4) * 4)));
- } else {
- vfmadd231ps(
- Zmm(jcp.zmm_start + dimN_reg_block),
- Zmm(current),
- EVEX_compress_addr(reg_srcB,
- sizeof(float) * (srcB_offset + dimN_reg_block),
- true));
- }
- }
- }
- }
- }
-
- add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
- * jcp.dimM_simd_block * sizeof(float));
- add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
- * jcp.dimK_4fma * sizeof(float));
- if (jcp.dimK_block > 1) {
- sub(reg_dimK_block_loop_cnt, 1);
- jnz(dimK_block_loop);
- }
-
- /******** Write C back to memory *******/
- for (int dimN_reg_block = 0;
- dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
- Zmm zmm(jcp.zmm_start + dimN_reg_block);
- vmovups(zword[reg_dstC +
- dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
- zmm);
- }
-
- sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
- jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
- add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
- * sizeof(float));
- if (jcp.dimN_block > 1) {
- sub(reg_dimN_block_loop_cnt, 1);
- jnz(dimN_block_loop);
- }
- }
-
- if (jcp.dimM_block > 1) {
- sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
- * jcp.dimK_reg_block * jcp.dimN_reg_block
- * jcp.dimK_4fma * sizeof(float));
- add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
- * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
- sub(reg_dimM_block_loop_cnt, 1);
- jnz(dimM_block_loop);
- }
- }
- };
-
- /* Preamble */
- // register used to handle long fma encoding
- preamble();
- mov(reg_srcA, reg_srcA_const);
- inner_loops();
-
- /* Postamble */
- postamble();
- ret();
-}
-
-namespace {
-bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block,
- int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
-{
- float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw;
- lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw;
- lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
- lhs *= sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs <= rhs);
-}
-
-bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
- int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
-{
- float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma
- * dimM_simdw;
- lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
- lhs *= sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs <= rhs);
-}
-
-bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
- int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
- float C)
-{
- float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block
- * dimK_4fma;
- lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
- * dimN_reg_block;
- lhs *= sizeof(float);
- float rhs = C * L2_cache_size;
- return (lhs <= rhs);
-}
-
-bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
- int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
- float C)
-{
- float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block;
- lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma;
- lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
- * dimN_reg_block;
- lhs *= sizeof(float);
- float rhs = C * L2_cache_size;
- return (lhs <= rhs);
-}
-} // namespace
-
-status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
- /*************** Choose dimN_reg_block (ic_simd_block)
- * *******************************/
- jcp.dimN = jcp.ic;
- /*Hardcoded to 16 because N = ic for bwd weights and
- innermost dimension for ic is assumed 16 in src transforms. This
- choice covers load latencies while maintaining simplicity of kernel
- for POR topologies. FIXME in future??: Will not work for future topologies
- when ic%16 != 0*/
- jcp.dimN_reg_block = jcp.ic_simd_block;
-
- /****************************** Choose dimK_block
- * **************************/
- // No freedom for choosing dimM_simd_block because ic_simd_block
- // is determined by input data format
- jcp.dimM_simd_block = jcp.oc_simd_block;
-
- auto test_cond1bis_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
- jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
- && (dimK_block > current_best);
- };
-
- auto test_cond1_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1,
- jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
- && (dimK_block > current_best);
- };
-
- auto test_cond2bis_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
- jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f)
- && (dimK_block > current_best);
- };
-
- auto test_cond2_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1,
- jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f)
- && (dimK_block > current_best);
- };
-
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block);
- if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma)
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block);
-
- jcp.dimK_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block);
- if (jcp.dimK_reg_block < jcp.dimK_block) {
- jcp.dimK_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK_block, 1, test_cond1_dimK_block);
- }
- jcp.dimK_block /= jcp.dimK_reg_block;
- jcp.dimK_nb_block
- = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block;
- jcp.tile_block_ur = jcp.dimK_reg_block;
- jcp.nb_tile_block_ur = jcp.dimK_block;
- jcp.tile_block = jcp.dimK_nb_block;
-
- /***************************** Choose dimN block
- * ****************************/
- auto test_cond2_dimN_block = [](
- jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
- return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block,
- jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block,
- jcp.dimN_reg_block, 0.5f)
- && (dimN_block > current_best);
- };
-
- jcp.dimN_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
- jcp.ic_block = jcp.dimN_block;
- jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block;
- jcp.nb_ic = jcp.dimN_nb_block;
-
- /********************************* Choose dimM block
- * ************************/
- jcp.dimM = jcp.oc;
-
- auto test_cond1_dimM_block = [](
- jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
- return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1,
- jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block,
- 1.0f)
- && (dimM_block > current_best)
- && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2;
- };
-
- jcp.dimM_block = get_divisor_satisfying_cond(
- jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
- jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
-
- jcp.sched_policy = WSCHED_WEI_S_D_G_W;
- return status::success;
-}
-
-status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
- const memory_desc_wrapper &diff_weights_d)
-{
- jcp.nthr = mkldnn_get_max_threads();
-
- const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
-
- jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = diff_dst_d.dims()[2];
- jcp.ow = diff_dst_d.dims()[3];
- jcp.kh = diff_weights_d.dims()[with_groups + 2];
- jcp.kw = diff_weights_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.r_pad = nstl::max(
- 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
- jcp.b_pad = nstl::max(
- 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
- jcp.ohp = jcp.oh;
- jcp.owp = jcp.ow;
- jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- bool ok_to_pad_channels = jcp.ngroups == 1;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- if (mayiuse(avx512_core))
- return status::unimplemented;
- if (!mayiuse(avx512_common))
- return status::unimplemented;
- else if (mayiuse(avx512_mic_4ops))
- jcp.ver = ver_4fma;
- else
- jcp.ver = ver_fma;
-
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
- // Winograd specific initialization
- jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- // Winograd kernel works only for 3x3 convolution with stride 1
- if (jcp.ngroups != 1)
- return status::unimplemented;
- if ((jcp.kh != 3) || (jcp.kw != 3))
- return status::unimplemented;
- if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
- return status::unimplemented;
- if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
- return status::unimplemented;
- if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
- return status::unimplemented;
-
- format_tag_t dat_tag = nChw16c;
- format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
-
- if (jcp.src_tag != dat_tag) return status::unimplemented;
- if (jcp.wei_tag != wei_tag) return status::unimplemented;
- if (jcp.dst_tag != dat_tag) return status::unimplemented;
-
- bool layout_consistency = true
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= diff_dst_d.padded_dims()[1]
- && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
- if (!layout_consistency) return status::unimplemented;
-
- /*************************** New Kernel Parameters
- * *****************************/
- jcp.ic_simd_block = simd_w;
- jcp.oc_simd_block = simd_w;
- jcp.dimK_4fma = 1;
- jcp.tile_4fma_padding = 0;
-
-#define MAX_4FMA_UR 8
- if (jcp.ver == ver_4fma) {
- auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma,
- int current_best) {
- return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR)
- && (dimK_4fma > current_best);
- };
- jcp.dimK_4fma = get_divisor_satisfying_cond(
- jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma);
- if (jcp.dimK_4fma == 1)
- jcp.dimK_4fma = 4;
- if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0)
- jcp.tile_4fma_padding = jcp.dimK_4fma
- - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma);
- }
-
- jcp.tile_4fma = jcp.dimK_4fma;
- /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src
- * transform
- * will not work correctly, this is solved by applying padding.*/
- jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
- jcp.dimN = jcp.ic;
- jcp.dimM = jcp.oc;
-
- jcp.double_buffering = true;
- if (jcp.double_buffering)
- jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2;
- else
- jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
- jcp.nb_reg = 32 - jcp.zmm_start;
-
- jcp.sched_policy = WSCHED_INVALID;
- status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
- assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
-
- jcp.tile_block_ur = jcp.dimK_reg_block;
- jcp.nb_tile_block_ur = jcp.dimK_block;
- jcp.tile_block = jcp.dimK_nb_block;
-
- jcp.ic_block = jcp.dimN_block;
- jcp.nb_ic = jcp.dimN_nb_block;
-
- jcp.oc_block = jcp.dimM_block;
- jcp.nb_oc = jcp.dimM_nb_block;
-
- return res;
-
-}
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
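
Note on the blocking heuristics deleted above: set_wsched_WEI_S_D_G_W_avx512_common and the check_cond*_wu helpers repeatedly look for the largest divisor of a dimension whose per-block working set still fits within a fraction C of L2_cache_size. The following is a minimal, self-contained C++ sketch of that selection pattern; the helper name and signature are illustrative only and do not reproduce the actual mkl-dnn get_divisor_satisfying_cond implementation.

#include <functional>

// Illustrative sketch: return the largest divisor of `number` accepted by
// `test`, falling back to `default_best` when no divisor qualifies. The real
// heuristic threads a jit_conv_winograd_conf_t through the predicate in the
// same way the Conf parameter is threaded here.
template <typename Conf>
int largest_divisor_satisfying(const Conf &cfg, int number, int default_best,
        const std::function<bool(const Conf &, int, int)> &test)
{
    int best = default_best;
    for (int divisor = 1; divisor <= number; ++divisor) {
        if (number % divisor != 0) continue;
        if (test(cfg, divisor, best)) best = divisor;
    }
    return best;
}

With a predicate modeled on check_cond2_wu (estimated footprint <= C * L2_cache_size and candidate > current best), this returns the largest cache-fitting block size, which is how dimK_block, dimN_block and dimM_block are chosen in the deleted code.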
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp
deleted file mode 100644
index 6c117143f5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp
+++ /dev/null
@@ -1,179 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
-#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "cpu_memory.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-//alpha determines the output tile_size
-constexpr int alpha = 6;
-constexpr int tile_size = 4;
-//simd length used for vectorization
-constexpr int simd_w = 16;
-
-struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator {
- _jit_avx512_common_conv_winograd_data_kernel_f32(
- jit_conv_winograd_conf_t ajcp)
- : jcp(ajcp)
- {
- //******************* First iter kernel ********************//
- this->gemm_loop_generate(true);
- gemm_loop_ker_first_iter
- = (decltype(gemm_loop_ker_first_iter)) this->getCode();
-
- //************** Subsequent iterations kernel **************//
- if (jcp.dimK_nb_block > 1) {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->gemm_loop_generate(false);
- gemm_loop_ker = (decltype(gemm_loop_ker))addr;
- }
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32)
-
- static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d);
-
- static status_t init_conf_kernel(
- jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
-
- jit_conv_winograd_conf_t jcp;
- void (*gemm_loop_ker)(float *, const float *, const float *);
- void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
-
-protected:
- using reg64_t = const Xbyak::Reg64;
- enum { typesize = sizeof(float) };
-
- void gemm_loop_generate(bool is_beta_zero);
-
- /* registers used for GEMM */
- reg64_t reg_dstC = abi_param1;
- reg64_t reg_srcA = abi_param2;
- reg64_t reg_srcB = abi_param3;
-
- reg64_t reg_dimM_block_loop_cnt = r10;
- reg64_t reg_dimK_block_loop_cnt = r11;
-};
-
-struct jit_avx512_common_conv_winograd_fwd_kernel_f32
- : _jit_avx512_common_conv_winograd_data_kernel_f32 {
- using _jit_avx512_common_conv_winograd_data_kernel_f32::
- _jit_avx512_common_conv_winograd_data_kernel_f32;
-
- static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
-};
-
-struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32
- : public _jit_avx512_common_conv_winograd_data_kernel_f32 {
- using _jit_avx512_common_conv_winograd_data_kernel_f32::
- _jit_avx512_common_conv_winograd_data_kernel_f32;
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d);
-};
-
-struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32
- : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32)
-
- jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
- jit_conv_winograd_conf_t ajcp)
- : jcp(ajcp)
- {
-
- //******************* First iter kernel ********************//
- {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->gemm_loop_generate(true);
- gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
- }
-
- if (jcp.tile_block > 1) {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->gemm_loop_generate(false);
- gemm_loop_ker = (decltype(gemm_loop_ker))addr;
- }
-
- if (jcp.ver == ver_4fma) {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->transpose_ker_generate();
- transpose_4fma_ker = (decltype(transpose_4fma_ker))addr;
- }
- }
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_dst_d,
- const memory_desc_wrapper &diff_weights_d);
-
- jit_conv_winograd_conf_t jcp;
- void (*gemm_loop_ker)(float *, const float *, const float *);
- void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
- void (*transpose_4fma_ker)(float *, float *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- enum { typesize = sizeof(float) };
-
- void gemm_loop_generate(bool is_first_tile);
- void transpose_ker_generate();
-
- reg64_t reg_origB = abi_param2;
- reg64_t reg_transB = abi_param1;
-
- reg64_t reg_dstC = abi_param1;
- reg64_t reg_srcA_const = abi_param2;
- reg64_t reg_srcB = abi_param3;
-
- reg64_t reg_sp = rsp;
- reg64_t reg_srcA = r9;
- reg64_t reg_nb_ic = r10;
- reg64_t reg_loop_cpt = r11;
- reg64_t reg_transB_idx = r13;
-
- /* Registers used by new kernel */
- reg64_t reg_dimM_block_loop_cnt = r10;
- reg64_t reg_dimK_block_loop_cnt = r12;
- reg64_t reg_dimN_block_loop_cnt = r11;
-};
-}
-}
-}
-
-#endif
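
Note on the kernel classes declared in the header above: each one emits two entry points into a single Xbyak code buffer, a first-iteration GEMM loop (beta = 0) and a steady-state loop (beta = 1), and captures them as raw function pointers via getCode()/getCurr(). The sketch below shows that dual-entry pattern in isolation; it assumes Xbyak is available on the include path and emits only a stub body, so it illustrates the mechanism rather than the deleted kernels themselves.

#include <xbyak/xbyak.h>

// Minimal sketch of the dual-entry-point pattern used by the Winograd
// kernels: two variants of one loop are generated back to back and exposed
// as separate function pointers into the same executable buffer.
struct two_entry_jit : Xbyak::CodeGenerator {
    using fn_t = void (*)();
    fn_t ker_first_iter = nullptr; // beta == 0: overwrite the accumulator
    fn_t ker = nullptr;            // beta == 1: accumulate into it

    two_entry_jit() {
        ker_first_iter = (fn_t)getCurr();
        emit_loop(/*is_beta_zero=*/true);

        align(64);             // keep the second entry point aligned
        ker = (fn_t)getCurr();
        emit_loop(/*is_beta_zero=*/false);
    }

private:
    void emit_loop(bool is_beta_zero) {
        (void)is_beta_zero;    // a real kernel emits the GEMM loop here
        ret();
    }
};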
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
deleted file mode 100644
index abddc19221..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
+++ /dev/null
@@ -1,1526 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_common_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-using namespace nstl;
-
-using jit_conv_ker_t = void (*)(jit_conv_call_s *);
-
-#define PIPELINE(field) \
- do { \
- p.field = p.field ## _prf; \
- p.field ## _prf = field; \
- } while (0)
-
-inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
- const void *src, const void *dst, const void *filt, const void *bias,
- int channel, int kh_padding)
-{
- PIPELINE(src);
- PIPELINE(dst);
- PIPELINE(filt);
- PIPELINE(bias);
- PIPELINE(channel);
- PIPELINE(kh_padding);
-
- if (p.src)
- ker(&p);
-}
-// The special case for the driver with ow-parallelization (FWD)
-// TODO: implement it for BWD_D and BWD_W too
-inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
- const void *src, const void *dst, const void *filt, const void *bias,
- int channel, int kh_padding, int owb)
-{
- PIPELINE(src);
- PIPELINE(dst);
- PIPELINE(filt);
- PIPELINE(bias);
- PIPELINE(channel);
- PIPELINE(kh_padding);
- PIPELINE(owb);
-
- if (p.src)
- ker(&p);
-}
-
-inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
- const void *src, const void *dst, const void *filt, const void *bias,
- int channel, int kh_padding, int kd_padding)
-{
- PIPELINE(src);
- PIPELINE(dst);
- PIPELINE(filt);
- PIPELINE(bias);
- PIPELINE(channel);
- PIPELINE(kh_padding);
- PIPELINE(kd_padding);
-
- if (p.src)
- ker(&p);
-}
-// The special case for the driver with ow-parallelization (FWD)
-// TODO: implement it for BWD_D and BWD_W too
-inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker,
- jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
- const void *bias, int channel, int kh_padding, int kd_padding, int owb)
-{
- PIPELINE(src);
- PIPELINE(dst);
- PIPELINE(filt);
- PIPELINE(bias);
- PIPELINE(channel);
- PIPELINE(kh_padding);
- PIPELINE(kd_padding);
- PIPELINE(owb);
-
- if (p.src)
- ker(&p);
-}
-
-void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
- const void *src, const void *dst, const void *filt, const void *bias,
- int channel, int d_index, int d_worksize,
- int kd_padding /* kd_work_size */, size_t kd_offset) {
- PIPELINE(src);
- PIPELINE(dst);
- PIPELINE(filt);
- PIPELINE(bias);
- PIPELINE(channel);
- PIPELINE(kd_padding);
- PIPELINE(d_worksize);
- PIPELINE(d_index);
- PIPELINE(kd_offset);
-
- if (p.src)
- ker(&p);
-}
-#define wht_blk_off(d, g, ...) \
- (pd()->with_groups() \
- ? (d).blk_off((g), __VA_ARGS__) \
- : (d).blk_off(__VA_ARGS__))
-
-template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
-void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
- dst_type>::prepare_padded_bias(const dst_data_t *&bias,
- const memory_tracking::grantor_t &scratchpad) const {
- if (!pd()->wants_padded_bias()) return;
-
- auto padded_bias = scratchpad.template get<dst_data_t>(
- key_conv_padded_bias);
- utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
- utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
- (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
- bias = padded_bias;
-}
-
-template <data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type>
-void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
-execute_forward_1d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- prepare_padded_bias(bias, this->scratchpad(ctx));
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
-
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow;
-
- int nthr;
- if (jcp.aligned_threads)
- nthr = jcp.aligned_threads;
- else
- nthr = mkldnn_get_max_threads();
-
- parallel(nthr, [&](const int ithr, const int nthr) {
- int start{0}, end{0}, start_copy;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t src_c_stride = src_d.blk_off(0, 1);
- size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
-
- for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
- start = start_copy;
- int n{0}, g{0}, occ{0}, owb{0};
-
- if (jcp.loop_order == loop_cwgn) {
- int dummy{0};
- nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
- g, jcp.ngroups, n, jcp.mb, dummy, 1);
- } else if (jcp.loop_order == loop_gncw) {
- int dummy{0};
- nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ,
- oc_chunks, owb, jcp.nb_ow, dummy, 1);
- } else {
- assert(!"unsupported loop order");
- }
-
- while (start < end) {
- int ocb = occ * jcp.nb_oc_blocking;
- int g_ocb = g * jcp.nb_oc + ocb;
- int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
-
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
- auto bias_w = bias ? bias + g_oc : nullptr;
- auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s);
- auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s);
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
-
- for (int icb = icb_l2;
- icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
- jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
- src_w, dst_w, wht_w, bias_w, icb, 1, owb);
-
- src_w += src_c_stride;
- wht_w += wht_ic_stride;
- }
- if (jcp.loop_order == loop_cwgn) {
- int dummy{0};
- nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
- g, jcp.ngroups, n, jcp.mb, dummy, 1);
- } else if (jcp.loop_order == loop_gncw) {
- int dummy{0};
- nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb,
- occ, oc_chunks, owb, jcp.nb_ow, dummy, 1);
- } else {
- assert(!"unsupported loop order");
- }
- }
- }
- jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
- src, dst, weights, bias, 0, 0, 0);
- });
-}
-
-template <data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type>
-void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
-execute_forward_2d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- prepare_padded_bias(bias, this->scratchpad(ctx));
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
-
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow;
-
- int nthr;
- if (jcp.aligned_threads)
- nthr = jcp.aligned_threads;
- else
- nthr = mkldnn_get_max_threads();
-
- parallel(nthr, [&](const int ithr, const int nthr) {
- int start{0}, end{0}, start_copy;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t src_h_stride = src_d.blk_off(0, 0, 1);
- size_t src_c_stride = src_d.blk_off(0, 1);
- size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
- size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
-
- for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
- start = start_copy;
- int n{0}, g{0}, occ{0}, oh_s{0}, owb{0};
-
- if (jcp.loop_order == loop_cwgn)
- nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
- g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
- else if (jcp.loop_order == loop_gncw)
- nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb,
- occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
-
- while (start < end) {
- int ocb = occ * jcp.nb_oc_blocking;
- int g_ocb = g * jcp.nb_oc + ocb;
- int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
-
- int work_rem = end - start;
-
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
- int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
- auto bias_w = bias ? bias + g_oc : nullptr;
-
- for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
- int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;
-
- auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s);
- auto src_w
- = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s);
- auto wht_w
- = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
-
- for (int icb = icb_l2;
- icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
- ++icb) {
- auto src_c = src_w;
- auto dst_c = dst_w;
- for (int oj = oh_b, ij = ih_b;
- oj < min(oh_e, oh_b + jcp.h_blocking);
- ++oj, ij += jcp.stride_h) {
- int dilate_h = jcp.dilate_h + 1;
- int i_t_overflow = div_up(max(0, -ij), dilate_h);
- int i_b_overflow = div_up(max(0, ij - jcp.ih
- + (jcp.kh - 1) * dilate_h + 1), dilate_h);
- int kh_padding = nstl::max(
- 0, jcp.kh - i_t_overflow - i_b_overflow);
-
- auto aux_src = src_c
- + i_t_overflow * dilate_h * src_h_stride;
- auto aux_wht = wht_w + i_t_overflow * wht_h_stride;
-
- jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker,
- par_conv, aux_src, dst_c, aux_wht, bias_w, icb,
- kh_padding, owb);
-
- src_c += src_h_stride * jcp.stride_h;
- dst_c += dst_h_stride;
- }
- src_w += src_c_stride;
- wht_w += wht_ic_stride;
- }
- }
-
- if (jcp.loop_order == loop_cwgn)
- nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
- g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
- else if (jcp.loop_order == loop_gncw)
- nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ,
- oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
- }
- }
-
- jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
- src, dst, weights, bias, 0, 0, 0);
- });
-}
-
-template <data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type>
-void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
-execute_forward_3d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- prepare_padded_bias(bias, this->scratchpad(ctx));
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
-
- parallel(0, [&](const int ithr, const int nthr) {
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int start{0}, end{0}, start_copy;
- int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh
- * jcp.nb_ow;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t src_d_stride = src_d.blk_off(0, 0, 1);
- size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
- size_t src_c_stride = src_d.blk_off(0, 1);
- size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
- size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
- size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
-
- for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
- start = start_copy;
- int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0};
-
- if (jcp.loop_order == loop_cwgn)
- nd_iterator_init(start,
- occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
- od_s, jcp.od, oh_s, jcp.oh);
- else if (jcp.loop_order == loop_gncw)
- nd_iterator_init(start,
- g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
- od_s, jcp.od, oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
-
- while (start < end) {
- int ocb = occ * jcp.nb_oc_blocking;
- int g_ocb = g * jcp.nb_oc + ocb;
- int g_oc = g_ocb * jcp.oc_block;
- int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
-
- int work_rem = end - start;
- int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
- int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
-
- int id_s = -jcp.f_pad + od_s * jcp.stride_d;
-
- int dilate_d = jcp.dilate_d + 1;
- int d_t_overflow = div_up(max(0, -id_s), dilate_d);
- int d_b_overflow = div_up(
- max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
- dilate_d);
- int kd_padding = nstl::max(0,
- jcp.kd - d_t_overflow - d_b_overflow);
-
- auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : 0;
- auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s);
- auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s,
- iw_s) + d_t_overflow * dilate_d * src_d_stride;
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
- + d_t_overflow * wht_d_stride;
-
- for (int icb = icb_l2;
- icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
- auto src_c = src_w;
- auto dst_c = dst_w;
- for (int oj = oh_s, ij = ih_s;
- oj < oh_e; ++oj, ij += jcp.stride_h)
- {
- int dilate_h = jcp.dilate_h + 1;
- int i_t_overflow = div_up(max(0, -ij), dilate_h);
- int i_b_overflow = div_up(
- max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h
- + 1),
- dilate_h);
- int kh_padding = nstl::max(0,
- jcp.kh - i_t_overflow - i_b_overflow);
- jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker,
- par_conv,
- src_c + i_t_overflow * dilate_h * src_h_stride,
- dst_c, wht_w + i_t_overflow * wht_h_stride,
- bias_w, icb, kh_padding, kd_padding, owb);
-
- src_c += src_h_stride * jcp.stride_h;
- dst_c += dst_h_stride;
- }
- src_w += src_c_stride;
- wht_w += wht_ic_stride;
- }
-
- if (jcp.loop_order == loop_cwgn)
- nd_iterator_jump(start, end,
- occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
- od_s, jcp.od, oh_s, jcp.oh);
- else if (jcp.loop_order == loop_gncw)
- nd_iterator_jump(start, end,
- g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
- od_s, jcp.od, oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
- }
- }
- jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
- src, dst, weights, bias, 0, 0, 0);
- });
-}
-
-template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
-
-template <data_type_t diff_dst_type, data_type_t wei_type,
- data_type_t diff_src_type>
-void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const
-{
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- parallel(0, [&](const int ithr, const int nthr) {
- int start{0}, end{0}, start_copy;
- int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
- int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
- size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
-
- for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
- start = start_copy;
- int n{0}, g{0}, icc{0};
- if (jcp.loop_order == loop_cgn) {
- int dummy{0};
- nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n,
- jcp.mb, dummy, 1);
- } else if (jcp.loop_order == loop_gnc) {
- int dummy{0};
- nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc,
- ic_chunks, dummy, 1);
- } else {
- assert(!"unsupported loop order");
- }
-
- while (start < end) {
- int icb = icc * jcp.nb_ic_blocking;
- int g_icb = g * jcp.nb_ic + icb;
- int g_ocb = g * jcp.nb_oc;
-
- auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
- auto diff_dst_w = diff_dst
- + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
-
- for (int ocb = ocb_l2;
- ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
- jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src_w, diff_dst_w, wht_w, 0, ocb, 1);
- diff_dst_w += diff_dst_c_stride;
- wht_w += wht_oc_stride;
- }
-
- if (jcp.loop_order == loop_cgn) {
- int dummy{0};
- nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups,
- n, jcp.mb, dummy, 1);
- } else if (jcp.loop_order == loop_gnc) {
- int dummy{0};
- nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc,
- ic_chunks, dummy, 1);
- } else {
- assert(!"unsupported loop order");
- }
- }
- }
-
- jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src, diff_dst, weights, 0, 0, 1);
- });
-}
-
-template <data_type_t diff_dst_type, data_type_t wei_type,
- data_type_t diff_src_type>
-void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const
-{
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- parallel(0, [&](const int ithr, const int nthr) {
- int start{0}, end{0}, start_copy;
- int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
- int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
- size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
- size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
- size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
-
- bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;
-
- for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
- start = start_copy;
- int n{0}, g{0}, icc{0}, ih_s{0};
- if (jcp.loop_order == loop_cgn)
- nd_iterator_init(start,
- icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
- else if (jcp.loop_order == loop_gnc)
- nd_iterator_init(start,
- g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
- else
- assert(!"unsupported loop order");
-
- while (start < end) {
- int icb = icc * jcp.nb_ic_blocking;
- int g_icb = g * jcp.nb_ic + icb;
- int g_ocb = g * jcp.nb_oc;
-
- int work_rem = end - start;
- int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
-
- auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
- auto diff_dst_w = diff_dst
- + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
-
- for (int ocb = ocb_l2;
- ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
- for (int ij = ih_s; ij < ih_e; ++ij) {
- int oj, k_len, k_lo;
- if (is_fast_path) { // dilate == 0 && stride == 1
- int i_t_overflow = max(0, jcp.kh - 1 - ij
- - jcp.t_pad);
- int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
- - jcp.b_pad);
- k_len = jcp.kh - i_t_overflow - i_b_overflow;
- k_lo = i_b_overflow;
- oj = ij + jcp.t_pad - i_b_overflow;
- } else if (jcp.dilate_h != 0) { // stride == 1
- int dilate_h = jcp.dilate_h + 1;
- // Note: use div_up to account for "holes" in filter
- int i_t_overflow
- = div_up(max(0, (jcp.kh - 1) * dilate_h
- - ij - jcp.t_pad), dilate_h);
- int i_b_overflow
- = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
- - jcp.ih + ij - jcp.b_pad), dilate_h);
- k_len = jcp.kh - i_t_overflow - i_b_overflow;
- k_lo = i_b_overflow;
- oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
- } else { // dilate == 0
- int i_t_overflow = max(0, (jcp.kh - 1 - ij
- - jcp.t_pad) / jcp.stride_h);
- int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
- - jcp.b_pad) / jcp.stride_h);
- int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
- + jcp.b_pad - ij) % jcp.stride_h);
- int overflow_kh_lo = (ij + jcp.t_pad)
- % jcp.stride_h;
-
- k_len = (overflow_kh_hi - overflow_kh_lo)
- / jcp.stride_h + 1 - i_t_overflow
- - i_b_overflow;
- k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
- oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
- }
- assert(k_len >= 0);
-
- jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src_w + ij * diff_src_h_stride,
- diff_dst_w + oj * diff_dst_h_stride,
- wht_w + k_lo * wht_h_stride,
- 0, ocb, k_len);
- }
- diff_dst_w += diff_dst_c_stride;
- wht_w += wht_oc_stride;
- }
-
- if (jcp.loop_order == loop_cgn)
- nd_iterator_jump(start, end,
- icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
- else if (jcp.loop_order == loop_gnc)
- nd_iterator_jump(start, end,
- g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
- else
- assert(!"unsupported loop order");
- }
- }
-
- jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src, diff_dst, weights, 0, 0, 1);
- });
-}
-
-template <data_type_t diff_dst_type, data_type_t wei_type,
- data_type_t diff_src_type>
-void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
- diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const
-{
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- parallel(0, [&](const int ithr, const int nthr) {
- int start{0}, end{0}, start_copy;
- int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
- int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
- balance211(work_amount, nthr, ithr, start, end);
- start_copy = start;
-
- auto par_conv = jit_conv_call_s();
- size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
- size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
- size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
- size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
- size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
- size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
- size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
-
- bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
- bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;
-
- for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
- start = start_copy;
- int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0};
- if (jcp.loop_order == loop_cgn)
- nd_iterator_init(start,
- icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
- ih_s, jcp.ih);
- else if (jcp.loop_order == loop_gnc)
- nd_iterator_init(start,
- g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
- ih_s, jcp.ih);
- else
- assert(!"unsupported loop order");
-
- while (start < end) {
- int icb = icc * jcp.nb_ic_blocking;
- int g_icb = g * jcp.nb_ic + icb;
- int g_ocb = g * jcp.nb_oc;
-
- int work_rem = end - start;
- int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
- int d_len = 0, d_lo = 0, d_oj = 0;
- if (is_fast_path_d) { // dilate == 0 && stride == 1
- int d_t_overflow = max(0, jcp.kd - 1 - id_s
- - jcp.f_pad);
- int d_b_overflow = max(0, jcp.kd - jcp.id + id_s
- - jcp.back_pad);
- d_len = jcp.kd - d_t_overflow - d_b_overflow;
- d_lo = d_b_overflow;
- d_oj = id_s + jcp.f_pad - d_b_overflow;
- } else if (jcp.dilate_d != 0) { // stride == 1
- int dilate_d = jcp.dilate_d + 1;
- // Note: use div_up to account for "holes" in filter
- int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d
- - id_s - jcp.f_pad), dilate_d);
- int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1
- - jcp.id + id_s - jcp.back_pad), dilate_d);
- d_len = jcp.kd - d_t_overflow - d_b_overflow;
- d_lo = d_b_overflow;
- d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
- } else { // dilate == 0
- int d_t_overflow = max(0, (jcp.kd - 1 - id_s
- - jcp.f_pad) / jcp.stride_d);
- int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s
- - jcp.back_pad) / jcp.stride_d);
- int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1
- + jcp.back_pad - id_s) % jcp.stride_d);
- int overflow_kd_lo = (id_s + jcp.f_pad)
- % jcp.stride_d;
-
- d_len = (overflow_kd_hi - overflow_kd_lo)
- / jcp.stride_d + 1 - d_t_overflow
- - d_b_overflow;
- d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
- d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
- }
- assert(d_len >= 0);
-
- auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb)
- + id_s * diff_src_d_stride;
- auto diff_dst_w = diff_dst
- + diff_dst_d.blk_off(n, g_ocb + ocb_l2)
- + d_oj * diff_dst_d_stride;
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
- + d_lo * wht_d_stride;
-
- for (int ocb = ocb_l2;
- ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
- for (int ij = ih_s; ij < ih_e; ++ij) {
- int oj, k_len, k_lo;
- if (is_fast_path_h) { // dilate == 0 && stride == 1
- int i_t_overflow = max(0, jcp.kh - 1 - ij
- - jcp.t_pad);
- int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
- - jcp.b_pad);
- k_len = jcp.kh - i_t_overflow - i_b_overflow;
- k_lo = i_b_overflow;
- oj = ij + jcp.t_pad - i_b_overflow;
- } else if (jcp.dilate_h != 0) { // stride == 1
- int dilate_h = jcp.dilate_h + 1;
- // Note: use div_up to account for "holes" in filter
- int i_t_overflow
- = div_up(max(0, (jcp.kh - 1) * dilate_h
- - ij - jcp.t_pad), dilate_h);
- int i_b_overflow
- = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
- - jcp.ih + ij - jcp.b_pad), dilate_h);
- k_len = jcp.kh - i_t_overflow - i_b_overflow;
- k_lo = i_b_overflow;
- oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
- } else { // dilate == 0
- int i_t_overflow = max(0, (jcp.kh - 1 - ij
- - jcp.t_pad) / jcp.stride_h);
- int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
- - jcp.b_pad) / jcp.stride_h);
- int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
- + jcp.b_pad - ij) % jcp.stride_h);
- int overflow_kh_lo = (ij + jcp.t_pad)
- % jcp.stride_h;
-
- k_len = (overflow_kh_hi - overflow_kh_lo)
- / jcp.stride_h + 1 - i_t_overflow
- - i_b_overflow;
- k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
- oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
- }
- assert(k_len >= 0);
-
- jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src_w + ij * diff_src_h_stride,
- diff_dst_w + oj * diff_dst_h_stride,
- wht_w + k_lo * wht_h_stride,
- 0, ocb, k_len, d_len);
- }
- diff_dst_w += diff_dst_c_stride;
- wht_w += wht_oc_stride;
- }
-
- if (jcp.loop_order == loop_cgn)
- nd_iterator_jump(start, end,
- icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
- ih_s, jcp.ih);
- else if (jcp.loop_order == loop_gnc)
- nd_iterator_jump(start, end,
- g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
- ih_s, jcp.ih);
- else
- assert(!"unsupported loop order");
- }
- }
-
- jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
- diff_src, diff_dst, weights, 0, 0, 1, 1);
- });
-}
-
-template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::
-jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd), kernel_(nullptr)
- , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
-{
- const auto &j = pd()->jcp_;
-
- nthr_ = j.nthr;
- nthr_mb_ = j.nthr_mb;
- nthr_g_ = j.nthr_g;
- nthr_oc_b_ = j.nthr_oc_b;
- nthr_ic_b_ = j.nthr_ic_b;
-
- kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
-
- if (j.ver == ver_4fma)
- trans_kernel_ = create_trans_src(&j);
-
- if (nthr_mb_ > 1)
- acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
-
- reducer_bias_ =
- new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::thread_info_t {
- const src_data_t *src;
- const diff_dst_data_t *diff_dst;
- const diff_weights_data_t *diff_weights;
- diff_weights_data_t *diff_bias;
-
- const memory_tracking::grantor_t scratchpad;
-
- src_data_t *tr_src;
- simple_barrier::ctx_t *tr_src_bctx;
-
- diff_dst_data_t *tr_diff_dst;
- simple_barrier::ctx_t *tr_diff_dst_bctx;
-
- diff_weights_data_t *wei_bia_reduction;
- simple_barrier::ctx_t *wei_bia_reduction_bctx;
-
- int ithr;
- int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
- int ithr_but_oc;
- int ithr_but_ic;
-
- int img_start = 0, img_end = 0, img_work;
- int g_start = 0, g_end = 0, g_work;
- int oc_b_start = 0, oc_b_end = 0, oc_b_work;
- int ic_b_start = 0, ic_b_end = 0, ic_b_work;
-
- thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
- const exec_ctx_t &ctx, int ithr)
- : scratchpad(self->scratchpad(ctx)), ithr(ithr)
- {
- diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- diff_bias = self->pd()->wants_padded_bias()
- ? scratchpad.template get<diff_weights_data_t>(
- key_conv_padded_bias)
- : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
- tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
- key_conv_tr_src_bctx);
-
- tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
- key_conv_tr_diff_dst);
- tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
- key_conv_tr_diff_dst_bctx);
-
- wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
- key_conv_wei_bia_reduction);
- wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
- key_conv_wei_bia_reduction_bctx);
-
- ithr_ic_b = ithr % self->nthr_ic_b_;
- ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
- ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
- ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;
-
- ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
- + ithr_ic_b;
-
- ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
- + ithr_oc_b;
-
- const auto &jcp = self->kernel_->jcp;
-
- /* reduction dimension */
- balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end);
- img_work = img_end - img_start;
-
- /* independent dimensions */
- balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
- g_work = g_end - g_start;
-
- balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start,
- oc_b_end);
- oc_b_work = oc_b_end - oc_b_start;
-
- balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start,
- ic_b_end);
- ic_b_work = ic_b_end - ic_b_start;
- }
-};
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
-
- diff_weights_data_t *diff_wei = ti->ithr_mb == 0
- ? (diff_weights_data_t*)ti->diff_weights
- : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
- diff_weights_data_t *diff_bia = ti->ithr_mb == 0
- ? (diff_weights_data_t*)ti->diff_bias
- : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
- + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
-
- // TODO: use memory descriptor with the same fmt as src (or use a macro :))
- auto tr_src_off = [&](int ithr_mb, int ic, int ij) {
- const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
- const size_t tr_chn_size = tr_row_size * jcp.ih;
- const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups;
-
- return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size;
- };
-
- auto uker_trans = [&](int img) {
- const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih;
-
- int start{0}, end{0};
- balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
- const int my_work = end - start;
-
- int g{0}, ic_b{0}, j{0};
- nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih);
- g += ti->g_start;
- ic_b += ti->ic_b_start;
-
- const int _ic = g * jcp.nb_ic + ic_b;
- src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
- src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
-
- assert(jcp.ic_block == 16);
- const int src_stride = jcp.iw * jcp.ic_block;
- const int tr_src_stride = jcp.tr_iw * jcp.ic_block;
-
- const int pf_depth = 2;
- struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth];
-
- for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
- pf_circ_buf[iwork % pf_depth] = {src1, tr_src1};
-
- if (iwork >= pf_depth - 1) {
- int old_idx = (iwork - pf_depth + 1) % pf_depth;
- auto ctx = jit_trans_src_t::ctx_t();
- ctx.src = pf_circ_buf[old_idx].src;
- ctx.tr_src = pf_circ_buf[old_idx].tr_src;
- ctx.src_prf = src1;
- ctx.tr_src_prf = tr_src1;
- (*trans_kernel_)(&ctx);
- }
- src1 += src_stride;
- tr_src1 += tr_src_stride;
- }
-#if 0
- // reference transposition
- const int l_pad = jcp.l_pad;
- const int iwlp = l_pad + jcp.iw;
- const int tr_iw = jcp.tr_iw;
-
- for (size_t iwork = start; iwork < end; iwork++) {
- PRAGMA_OMP_SIMD()
-# pragma unroll
- for (int i = 0; i < l_pad; i++)
- for (int j = 0; j < jcp.ic_block; j++)
- tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
-
- PRAGMA_OMP_SIMD()
-# pragma unroll
- for (int i = l_pad; i < iwlp; i++)
- for (int j = 0; j < jcp.ic_block; j++)
- tr_src1[j * jcp.tr_iw + i]
- = (src_data_t)src1[(i - l_pad) * 16 + j];
-
- PRAGMA_OMP_SIMD()
-# pragma unroll
- for (int i = iwlp; i < tr_iw; i++)
- for (int j = 0; j < jcp.ic_block; j++)
- tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
-
- src1 += src_stride;
- tr_src1 += tr_src_stride;
- }
-#endif
- };
-
- if (jcp.is_1stconv && jcp.ver == ver_4fma) {
- /* prepare contexts */
- auto tr_ctx = jit_trans_src_t::ctx_t();
- tr_ctx.tr_src = ti->tr_src
- + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
-
- assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
- tr_ctx.nthr_oc_b = nthr_oc_b_;
- int ih_start{0}, ih_end{0};
- balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
- tr_ctx.tr_src_ih_start = ih_start;
- tr_ctx.tr_src_ih_end = ih_end;
- tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
-
- auto p = jit_conv_call_s();
- p.src = tr_ctx.tr_src;
-
- /* zero diff_bias if applicable */
- if (jcp.with_bias && ti->ithr_ic_b == 0) {
- assert(jcp.oc_block == 16);
- for (int oc_b = ti->ic_b_start; oc_b < ti->oc_b_end; ++oc_b) {
- diff_weights_data_t *db = &diff_bia[oc_b * 16];
- for (int o = 0; o < 16; ++o)
- db[o] = 0;
- }
- }
-
- for (int img = ti->img_start; img < ti->img_end; ++img) {
- p.flags = (img == ti->img_start) * FLAG_MB_FIRST;
-
- for (int g = ti->g_start; g < ti->g_end; ++g) {
- for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
- const int _ic = g * jcp.nb_ic + ic_b;
- tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)];
-
- (*trans_kernel_)(&tr_ctx);
-
- if (ic_b == 0)
- p.flags |= FLAG_IC_FIRST;
- else
- p.flags &= ~FLAG_IC_FIRST;
-
- for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
- const int _oc = g * jcp.nb_oc + oc_b;
- p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
-
- const size_t off =
- wht_blk_off(diff_weights_d, g, oc_b, ic_b);
- p.filt = diff_wei + off;
- p.bias = diff_bia + _oc * jcp.oc_block;
-
- kernel_->jit_ker(&p);
- }
- }
- }
- }
- } else {
- for (int img = ti->img_start; img < ti->img_end; ++img) {
- auto p = jit_conv_call_s();
-
- if (jcp.ver == ver_4fma) {
- /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
- using simple_barrier::barrier;
- if (nthr_oc_b_ > 1)
- barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
- uker_trans(img);
- if (nthr_oc_b_ > 1)
- barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
- }
-
- for (int g = ti->g_start; g < ti->g_end; ++g) {
- for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
- for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
- const int _oc = g * jcp.nb_oc + oc_b;
- const int _ic = g * jcp.nb_ic + ic_b;
-
- jit_conv_ker_pipeline(kernel_->jit_ker, p,
- jcp.ver == ver_4fma
- ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
- : &ti->src[src_d.blk_off(img, _ic)],
- &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
- diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
- 0, (img == ti->img_start), 0);
-
- }
- }
- }
-
- const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
- const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
- jit_conv_ker_pipeline(kernel_->jit_ker, p,
- jcp.ver == ver_4fma
- ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
- : &ti->src[src_d.blk_off(img + 1, _ic)],
- &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
- diff_wei + wht_blk_off(
- diff_weights_d, ti->g_start,
- ti->oc_b_start, ti->ic_b_start),
- 0, 0, 0);
- }
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
-{
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- const int wei_size
- = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;
-
- diff_weights_data_t *diff_wei = ti->ithr_mb == 0
- ? (diff_weights_data_t*)ti->diff_weights
- : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
- diff_weights_data_t *diff_bia = ti->ithr_mb == 0
- ? (diff_weights_data_t*)ti->diff_bias
- : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
- + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
-
- const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
- const int input_step = jcp.ih * jcp.iw * inp_mult;
- const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
- int img{0}, od_s{0};
- int img_start = ti->img_start, img_end = ti->img_end;
- nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
- const int img_first = img;
-
- while (img_start < img_end) {
- auto p = jit_conv_call_s();
-
- int work_rem = img_end - img_start;
- const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
- const int id_s = od_s * jcp.stride_d;
- const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
- const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
- const int kd_back_pad
- = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd);
- int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw
- * jcp.ic_block * jcp.oc_block * jcp.typesize_out;
-
- for (int g = ti->g_start; g < ti->g_end; ++g) {
- for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
- for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
- const int _oc = g * jcp.nb_oc + oc_b;
- const int _ic = g * jcp.nb_ic + ic_b;
-
- auto src = &ti->src[src_d.blk_off(img, _ic)
- + ik_overlap * input_step];
- auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)
- + od_s * output_step];
-
- jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst,
- diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
- diff_bia + _oc * 16, (img == img_first), od_s, od_e,
- jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off);
-
- if (ic_b == 0) p.flags = 0;
- else p.flags = 1;
- }
- }
- }
-
- const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
- const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
- jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p,
- &ti->src[src_d.blk_off(img + 1, _ic)],
- &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
- diff_wei + wht_blk_off(diff_weights_d, ti->g_start,
- ti->oc_b_start, ti->ic_b_start),
- diff_bia, 0, 0, 0, 0, 0);
- nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
- const int bia_size = jcp.ngroups * jcp.oc;
- const diff_weights_data_t *diff_bias_ws
- = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
-
- /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
- simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
-
- const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
- const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
-
- int start{0}, end{0};
- balance211(work, nthr_mb_, ti->ithr_mb, start, end);
- if (start == end) return;
-
- for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
- int w = start;
- int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
- nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
- ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
- while (w < end) {
- const int g = ti->g_start + sub_g_start;
- const int oc_b = ti->oc_b_start + sub_oc_b_start;
- const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
- const int kh = sub_ic_b_kh_start % jcp.kh;
-
- const int acc_size
- = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
- * jcp.kw * jcp.ic_block * jcp.oc_block;
-
- const size_t off
- = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);
-
- diff_weights_data_t *d
- = (diff_weights_data_t *)ti->diff_weights + off;
- diff_weights_data_t *s
- = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
-
- acc_ker_->accumulate(d, s, acc_size);
-
- nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
- ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
- }
-
- if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) {
- if (ti->ithr == 0)
- acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias,
- diff_bias_ws, bia_size);
- diff_bias_ws += bia_size;
- }
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
- * jcp.kd;
-
- /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
- simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
-
- const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
- const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
-
- int start{0}, end{0};
- balance211(work, nthr_mb_, ti->ithr_mb, start, end);
- if (start == end) return;
-
- for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
- int w = start;
- int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
- nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
- ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
- while (w < end) {
- const int g = ti->g_start + sub_g_start;
- const int oc_b = ti->oc_b_start + sub_oc_b_start;
- const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
- const int kd = sub_ic_b_kh_start % jcp.kd;
-
- const int acc_size
- = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
- * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;
-
- const size_t off
- = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
- diff_weights_data_t *d
- = (diff_weights_data_t *)ti->diff_weights + off;
- diff_weights_data_t *s
- = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
- acc_ker_->accumulate(d, s, acc_size);
-
- nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
- ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
- }
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
-
- auto rb = this->reducer_bias_;
- assert(nthr_ == rb->balancer().nthr_);
-
- const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
- ti->scratchpad, prefix_reducer_bia);
-
- const auto &jcp = kernel_->jcp;
-
- if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
-
- const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
- const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
-
- if (b_njobs == 0) return;
-
- /* reduction dimension */
- int img_start{0}, img_end{0};
- balance211(jcp.mb, rb->balancer().nthr_per_group_,
- rb->balancer().id_in_group(ti->ithr), img_start, img_end);
-
- /* jobs */
- int g_start{0}, ocb_start{0};
- nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
- for (int img = img_start; img < img_end; ++img) {
- int g = g_start, ocb = ocb_start;
- for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
- const size_t _oc = g * jcp.nb_oc + ocb;
-
- const diff_dst_data_t *d_dst
- = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
- diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
- ti->diff_bias, reducer_bia_scratchpad)
- + b_job_loc * rb->balancer().job_size_;
-
- if (img == img_start)
- for (int o = 0; o < 16; ++o)
- d_bias[o] = 0;
- for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
- PRAGMA_OMP_SIMD()
- for (int o = 0; o < 16; ++o)
- d_bias[o] += d_dst[o];
- d_dst += 16;
- }
-
- nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
- }
- }
-
- rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
-
- const auto &jcp = kernel_->jcp;
-
- const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh
- * jcp.kw * jcp.kd;
- const int bia_size = jcp.ngroups * jcp.oc;
- const diff_weights_data_t *diff_bias_ws
- = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
-
- if (nthr_mb_ > 1) mkldnn_thr_barrier();
-
- if (ti->ithr == 0)
- {
- for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
- acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
- diff_bias_ws += bia_size;
- }
- }
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx) const
-{
- const auto &j = pd()->jcp_;
- auto scratchpad = this->scratchpad(ctx);
-
- if (j.ver == ver_4fma) {
- if (!j.is_1stconv) {
- // XXX: See the comment about tr_iw and guarding elements in
- // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
- const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
- const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
-
- auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
- /* to avoid NaNs in computations we zero tail num_guard_elems for
- * each possible thread group */
-
- for (int ithr = 1; ithr <= max_nthr; ++ithr) {
- src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
- for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
- ts[i] = 0;
- }
- }
-
- if (j.nthr_oc_b > 1) {
- const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
- auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
- key_conv_tr_src_bctx);
- for (int i = 0; i < tr_src_bctx_size; ++i)
- simple_barrier::ctx_init(&tr_src_bctx[i]);
- }
- }
-
- if (nthr_mb_ > 1) {
- simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
- key_conv_wei_bia_reduction_bctx));
- }
-
- const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
- prefix_reducer_bia);
- auto rb = this->reducer_bias_;
- rb->init(reducer_bia_scratchpad);
-}
-
-template <data_type_t src_type, data_type_t diff_dst_type,
- data_type_t diff_weights_type>
-void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
- diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx) const {
- prepare_scratchpad_data(ctx);
-
- parallel(nthr_, [&](const int ithr, const int nthr) {
- assert(nthr_ == nthr);
-
- thread_info_t thread_info(this, ctx, ithr);
-
- if (utils::one_of(pd()->ndims(), 3, 4)) {
- compute_diff_weights(&thread_info);
- if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
- if (pd()->with_bias()) compute_diff_bias(&thread_info);
- } else if (pd()->ndims() == 5) {
- compute_diff_weights_3d(&thread_info);
- if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
- if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
- } else {
- assert(false);
- }
- });
-
- /* TODO: put that into compute_diff_bias() */
- if (pd()->wants_padded_bias()) {
- auto diff_bias = scratchpad(ctx).template get<const diff_weights_data_t>(
- key_conv_padded_bias);
- auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);
- for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
- diff_bias_in[oc] = diff_bias[oc];
- }
-}
-
-template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp
deleted file mode 100644
index 3341c3ebe0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp
+++ /dev/null
@@ -1,302 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
-#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_barrier.hpp"
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_reducer.hpp"
-
-#include "jit_transpose_src_utils.hpp"
-#include "jit_avx512_common_conv_kernel.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type,
- impl::data_type_t wei_type = src_type,
- impl::data_type_t dst_type = src_type>
-struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
- jit_avx512_common_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, wei_type, dst_type, dst_type,
- data_type::undef)
- && !has_zero_dim_memory();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_common_conv_fwd_kernel::init_conf(
- jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_,
- *attr(), mkldnn_get_max_threads());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad,
- jcp_);
-
- return status;
- }
-
- jit_conv_conf_t jcp_;
- };
-
- jit_avx512_common_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- {
- kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_,
- *pd()->attr());
- }
- ~jit_avx512_common_convolution_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (pd()->ndims() == 3)
- execute_forward_1d(ctx);
- else if (pd()->ndims() == 4)
- execute_forward_2d(ctx);
- else if (pd()->ndims() == 5)
- execute_forward_3d(ctx);
- else
- assert(false);
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-
- return status::success;
- }
-
-private:
- void prepare_padded_bias(const dst_data_t *&bias,
- const memory_tracking::grantor_t &scratchpad) const;
- void execute_forward_1d(const exec_ctx_t &ctx) const;
- void execute_forward_2d(const exec_ctx_t &ctx) const;
- void execute_forward_3d(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_common_conv_fwd_kernel *kernel_;
-};
-
-template <impl::data_type_t diff_dst_type,
- impl::data_type_t wei_type = diff_dst_type,
- impl::data_type_t diff_src_type = diff_dst_type>
-struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
- jit_avx512_common_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(diff_src_type, wei_type,
- data_type::undef, diff_dst_type, data_type::undef)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status =
- jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_,
- *desc(), *diff_src_md(), *weights_md(), *diff_dst_md());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
- scratchpad, jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
- auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
- OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i,
- OIdhw16o16i, gOIdhw16o16i);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- jit_avx512_common_convolution_bwd_data_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); }
- ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; };
-
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (pd()->ndims() == 3)
- execute_backward_data_1d(ctx);
- else if (pd()->ndims() == 4)
- execute_backward_data_2d(ctx);
- else if (pd()->ndims() == 5)
- execute_backward_data_3d(ctx);
- else
- assert(false);
- return status::success;
- }
-
-private:
- void execute_backward_data_1d(const exec_ctx_t &ctx) const;
- void execute_backward_data_2d(const exec_ctx_t &ctx) const;
- void execute_backward_data_3d(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_;
-};
-
-template <impl::data_type_t src_type,
- impl::data_type_t diff_dst_type = src_type,
- impl::data_type_t diff_weights_type = src_type>
-struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
- jit_avx512_common_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, diff_weights_type,
- diff_weights_type, diff_dst_type, data_type::undef)
- && !has_zero_dim_memory();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32::
- init_conf(jcp_, *desc(), src_md_, diff_weights_md_,
- diff_bias_md_, diff_dst_md_);
- if (status != status::success) return status;
-
- init_balancers();
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
- scratchpad, jcp_);
-
- auto reducer_bia_scratchpad = memory_tracking::registrar_t(
- scratchpad, memory_tracking::names::prefix_reducer_bia);
- reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
-
- return status;
- }
-
- jit_conv_conf_t jcp_;
- typename cpu_reducer_t<diff_weights_type>::conf_t reducer_bia_conf_;
-
- private:
- void init_balancers() {
- const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
- if (with_bias()) {
- reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
- jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
- max_buffer_size));
- }
- }
- };
-
- jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd);
- ~jit_avx512_common_convolution_bwd_weights_t() {
- delete kernel_;
- if (trans_kernel_)
- delete trans_kernel_;
- if (acc_ker_)
- delete acc_ker_;
- delete reducer_bias_;
- }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- void prepare_scratchpad_data(const exec_ctx_t &ctx) const;
- struct thread_info_t;
- void compute_diff_weights(const thread_info_t *) const;
- void compute_diff_weights_3d(const thread_info_t *) const;
- void reduce_diff_weights(const thread_info_t *) const;
- void reduce_diff_weights_3d(const thread_info_t *) const;
- void compute_diff_bias(const thread_info_t *) const;
- void compute_diff_bias_3d(const thread_info_t *) const;
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;
-
- jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_;
- jit_trans_src_t *trans_kernel_;
- cpu_accumulator_1d_t<diff_weights_type> *acc_ker_;
- cpu_reducer_t<diff_weights_type> *reducer_bias_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp
deleted file mode 100644
index 62247c0264..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp
+++ /dev/null
@@ -1,1215 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifdef __INTEL_COMPILER
-#include <immintrin.h>
-#endif
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_common_convolution_winograd.hpp"
-
-#ifndef _MSC_VER
-#define pragma_unroll _Pragma("unroll")
-#else
-#define pragma_unroll
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace memory_tracking::names;
-
-namespace {
-
-unsigned int LLC_cache_size = get_cache_size(3, false);
-
-void inline load_ps(float *dest, const float *src_mem) {
-#ifdef __INTEL_COMPILER
- __m512 *Iv512 = (__m512 *)dest;
- Iv512[0] = _mm512_load_ps(src_mem);
-#else
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v];
-#endif
-}
-
-void inline store_output(float *dest, const float *data, bool streamout) {
-#ifdef __INTEL_COMPILER
- if (streamout)
- _mm512_stream_ps(dest, *((__m512 *)data));
- else
- _mm512_store_ps(dest, *((__m512 *)data));
-#else
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- dest[v] = data[v];
-#endif
-}
-
-void inline accum_output(
- float *dest, float *data, bool streamout, bool with_relu_postsum) {
-#ifdef __INTEL_COMPILER
- __m512 _data = _mm512_loadu_ps(data);
- __m512 _dest = _mm512_loadu_ps(dest);
- _data = _mm512_add_ps(_data, _dest);
- if (with_relu_postsum)
- _data = _mm512_max_ps(_data, _mm512_setzero_ps());
- if (streamout)
- _mm512_stream_ps(dest, _data);
- else
- _mm512_store_ps(dest, _data);
-#else
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- data[v] += dest[v];
-
- if (with_relu_postsum) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- if (data[v] < 0.f)
- data[v] = 0.f;
- }
-
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- dest[v] = data[v];
-#endif
-}
-}
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) {
- float Fw[6][16];
- float T[6][3][16];
- float t0[16];
- float t1[16];
- float t2[16];
-
- for (int j = 0; j < 16; j++) {
-#pragma unroll
- for (int i = 0; i < 3; i++) {
- PRAGMA_OMP_SIMD()
- for (int k = 0; k < 16; k++) {
- t0[k] = 0.26890756302521f * F[2][i][j][k];
- t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k];
- t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k];
-
- T[0][i][k] = 1.13777777777778f * F[0][i][j][k];
- T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k];
- T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k];
- T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k];
- T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k];
- T[5][i][k] = F[2][i][j][k];
- }
- }
-#pragma unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int k = 0; k < 16; k++) {
- t0[k] = 0.26890756302521f * T[i][2][k];
- t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k];
- t2[k] = t0[k] + 0.119514472455649f * T[i][0][k];
-
- Fw[0][k] = 1.13777777777778f * T[i][0][k];
- Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k];
- Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k];
- Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k];
- Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k];
- Fw[5][k] = T[i][2][k];
-#pragma unroll
- for (int l = 0; l < 6; l++) {
- Fw_[i][l][j][k] = Fw[l][k];
- }
- }
- }
- }
-}
-
-void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) {
- float T[4][6][16];
- float t0[16];
- float t1[16];
- float t2[16];
- float t3[16];
-
-#pragma unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = Mw[1][i][v] + Mw[2][i][v];
- t1[v] = Mw[3][i][v] + Mw[4][i][v];
- t2[v] = Mw[1][i][v] - Mw[2][i][v];
- t3[v] = Mw[3][i][v] - Mw[4][i][v];
-
- T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v];
- T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f;
- T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
- T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v];
- }
- }
-#pragma unroll
- for (int i = 0; i < 4; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = T[i][1][v] + T[i][2][v];
- t1[v] = T[i][3][v] + T[i][4][v];
- t2[v] = T[i][1][v] - T[i][2][v];
- t3[v] = T[i][3][v] - T[i][4][v];
-
- O[i][0][v] = t0[v] + t1[v] + T[i][0][v];
- O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f;
- O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
- O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v];
- }
- }
-}
-
-
-void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16])
-{
- const float rcp3 = 1.0f / 3.0f;
- const float rcp4 = 1.0f / 4.0f;
- const float rcp6 = 1.0f / 6.0f;
- const float rcp12 = 1.0f / 12.0f;
- const float rcp24 = 1.0f / 24.0f;
- float t0[16];
- float t1[16];
- float t2[16];
- float t3[16];
- float t4[16];
- float T[6][4][16];
-
-pragma_unroll
- for (int i = 0; i < 4; i++) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < 16; j++) {
- t0[j] = F[2][i][j] * rcp6;
- t1[j] = F[0][i][j] * -rcp6 - t0[j];
- t2[j] = F[0][i][j] * rcp24 + t0[j];
- t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6;
- t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3;
-
- T[0][i][j] = F[0][i][j] * rcp4;
- T[1][i][j] = t1[j] - t3[j];
- T[2][i][j] = t1[j] + t3[j];
- T[3][i][j] = t2[j] + t4[j];
- T[4][i][j] = t2[j] - t4[j];
- T[5][i][j] = F[3][i][j];
- }
- }
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < 16; j++) {
- t0[j] = T[i][2][j] * rcp6;
- t1[j] = T[i][0][j] * -rcp6 - t0[j];
- t2[j] = T[i][0][j] * rcp24 + t0[j];
- t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6;
- t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3;
-
- Fw[i][0][j] = T[i][0][j] * rcp4;
- Fw[i][1][j] = t1[j] - t3[j];
- Fw[i][2][j] = t1[j] + t3[j];
- Fw[i][3][j] = t2[j] + t4[j];
- Fw[i][4][j] = t2[j] - t4[j];
- Fw[i][5][j] = T[i][3][j];
- }
- }
-}
-
-void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16])
-{
- float T[4][6][16];
- float M_[3][16];
- float t0[16];
- float t1[16];
- float t2[16];
-
- for (int j = 0; j < 16; j++) {
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int l = 0; l < 16; l++) {
- t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l];
- t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l];
- t2[l] = t1[l] * 4.0f + Mw[5][i][j][l];
-
- T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l];
- T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) +
- 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]);
- T[2][i][l] = t0[l] + t2[l];
- }
- }
-pragma_unroll
- for (int i = 0; i < 3; i++) {
- PRAGMA_OMP_SIMD()
- for (int l = 0; l < 16; l++) {
- t0[l] = T[i][1][l] + T[i][2][l];
- t1[l] = T[i][3][l] + T[i][4][l];
- t2[l] = t1[l] * 4.0f + T[i][5][l];
-
- M_[0][l] = T[i][0][l] + t0[l] + t1[l];
- M_[1][l] = (T[i][1][l] - T[i][2][l]) +
- 2.0f * (T[i][3][l] - T[i][4][l]);
- M_[2][l] = t0[l] + t2[l];
-
- for (int k = 0; k < 3; k++) {
- M[i][k][j][l] = M_[k][l];
- }
- }
- }
- }
-}
-
-void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16])
-{
- float T[6][6][16];
- float t0[16];
- float t1[16];
- float t2[16];
- float t3[16];
- float t4[16];
- float t5[16];
-
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = I[2][i][v] * -2.25f + I[4][i][v];
- t1[v] = I[1][i][v] * -2.25f + I[3][i][v];
- t2[v] = I[2][i][v] * -0.390625f + I[4][i][v];
- t3[v] = I[1][i][v] * -0.390625f + I[3][i][v];
- t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v];
- t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v];
-
- T[0][i][v] = I[2][i][v] * -2.640625f + t4[v];
- T[1][i][v] = t1[v] * 0.625f + t0[v];
- T[2][i][v] = t1[v] * -0.625f + t0[v];
- T[3][i][v] = t3[v] * 1.5f + t2[v];
- T[4][i][v] = t3[v] * -1.5f + t2[v];
- T[5][i][v] = I[3][i][v] * -2.640625f + t5[v];
- }
- }
-
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = T[i][2][v] * -2.25f + T[i][4][v];
- t1[v] = T[i][1][v] * -2.25f + T[i][3][v];
- t2[v] = T[i][2][v] * -0.390625f + T[i][4][v];
- t3[v] = T[i][1][v] * -0.390625f + T[i][3][v];
- t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v];
- t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v];
-
- Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v];
- Iw[i][1][v] = t1[v] * 0.625f + t0[v];
- Iw[i][2][v] = t1[v] * -0.625f + t0[v];
- Iw[i][3][v] = t3[v] * 1.5f + t2[v];
- Iw[i][4][v] = t3[v] * -1.5f + t2[v];
- Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v];
- }
- }
-}
-
-void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16])
-{
- float T[6][4][16];
- float t0[16];
- float t1[16];
- float t2[16];
- float t3[16];
- float t4[16];
-
-pragma_unroll
- for (int i = 0; i < 4; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = F[2][i][v] * 0.26890756302521f;
- t1[v] = F[0][i][v] * -0.688403361344538f - t0[v];
- t2[v] = F[0][i][v] * 0.119514472455649f + t0[v];
- t3[v] = F[1][i][v] * 0.430252100840336f +
- F[3][i][v] * 0.168067226890756f;
- t4[v] = F[1][i][v] * 0.179271708683473f +
- F[3][i][v] * 0.403361344537815f;
-
- T[0][i][v] = F[0][i][v] * 1.13777777777778f;
- T[1][i][v] = t1[v] - t3[v];
- T[2][i][v] = t1[v] + t3[v];
- T[3][i][v] = t2[v] + t4[v];
- T[4][i][v] = t2[v] - t4[v];
- T[5][i][v] = F[3][i][v];
- }
- }
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- for (int v = 0; v < 16; v++) {
- t0[v] = T[i][2][v] * 0.26890756302521f;
- t1[v] = T[i][0][v] * -0.688403361344538f - t0[v];
- t2[v] = T[i][0][v] * 0.119514472455649f + t0[v];
- t3[v] = T[i][1][v] * 0.430252100840336f +
- T[i][3][v] * 0.168067226890756f;
- t4[v] = T[i][1][v] * 0.179271708683473f +
- T[i][3][v] * 0.403361344537815f;
-
- Fw[i][0][v] = T[i][0][v] * 1.13777777777778f;
- Fw[i][1][v] = t1[v] - t3[v];
- Fw[i][2][v] = t1[v] + t3[v];
- Fw[i][3][v] = t2[v] + t4[v];
- Fw[i][4][v] = t2[v] - t4[v];
- Fw[i][5][v] = T[i][3][v];
- }
- }
-}
-
-void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16])
-{
- float T[3][6][16];
- float t0[16];
- float t1[16];
- float t2[16];
- float M_[3][16];
-
- for (int j = 0; j < 16; j++) {
-pragma_unroll
- for (int i = 0; i < 6; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v];
- t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v];
- t2[v] = t1[v] * 2.25f + Mw[5][i][j][v];
-
- T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v];
- T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) +
- 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]);
- T[2][i][v] = t0[v] * 0.390625f + t2[v];
- }
- }
-pragma_unroll
- for (int i = 0; i < 3; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- t0[v] = T[i][1][v] + T[i][2][v];
- t1[v] = T[i][3][v] + T[i][4][v];
- t2[v] = t1[v] * 2.25f + T[i][5][v];
-
- M_[0][v] = T[i][0][v] + t0[v] + t1[v];
- M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) +
- 1.5f * (T[i][3][v] - T[i][4][v]);
- M_[2][v] = t0[v] * 0.390625f + t2[v];
- }
-
-pragma_unroll
- for (int k = 0; k < 3; k++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < 16; v++) {
- M[i][k][j][v] = M_[k][v];
- }
- }
- }
- }
-}
-
-template <bool is_fwd>
-void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp, bool streamout = true)
-{
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
- const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
- const int wp_max = inpw + l_pad;
- const int hp_max = inph + t_pad;
- float Iw[alpha][alpha][simd_w];
- float I[alpha][alpha][simd_w];
-
- array_offset_calculator<float, 5> input(inp,
- jcp.mb, jcp.dimK/simd_w, inph, inpw,
- simd_w);
- array_offset_calculator<float, 8> output(tinp,
- jcp.dimN_nb_block, alpha, alpha,
- jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
- jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- int tile_base_index = image * jcp.itiles * jcp.jtiles;
- int tile_block_ur = tile_base_index % jcp.tile_block_ur;
- int nb_tile_block_ur =
- (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
- int tile_block =
- (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
-
- for (int tj = 0; tj < jcp.jtiles; tj++) {
- for (int ti = 0; ti < jcp.itiles; ti++) {
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if ((t_pad <= ydim) && (ydim < hp_max)) {
- float *pinp_j = inp + (ydim - t_pad) * inpw * 16 ;
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if ((l_pad <= xdim) && (xdim < wp_max)) {
- float *pinp_i = pinp_j + (xdim - l_pad) * 16;
- load_ps(I[j][i], pinp_i);
- } else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- } else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
-
- trans_I_4x4_3x3(Iw, I);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- store_output(&(output(tile_block, j, i,
- nb_tile_block_ur, 0, 0,
- tile_block_ur, 0)),
- Iw[j][i], streamout);
- }
- }
- tile_block_ur++;
- if (tile_block_ur >= jcp.tile_block_ur) {
- tile_block_ur = 0;
- nb_tile_block_ur++;
- }
- if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-}
-
-template <bool is_fwd>
-void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
- float *wp, float *twp)
-{
- const int kh = 3;
- const int kw = 3;
- array_offset_calculator<float, 6> input(wp,
- jcp.oc/jcp.oc_simd_block,
- jcp.ic/jcp.ic_simd_block,
- jcp.kh, jcp.kw,
- simd_w, simd_w);
- array_offset_calculator<float, 8> output(twp,
- jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimK_nb_block,
- jcp.dimM_block, jcp.dimK_block,
- simd_w, simd_w);
- float Fw[alpha][alpha][simd_w][simd_w];
- float F[kh][kw][simd_w][simd_w];
-
- for (int j = 0; j < kh; j++) {
- for (int i = 0; i < kw; i++) {
- for (int v1 = 0; v1 < simd_w; v1++) {
- float *base_inp = is_fwd
- ? &(input(0, 0, j, i, v1, 0))
- : &(input(0, 0, 2 - j, 2 - i, v1, 0));
- PRAGMA_OMP_SIMD()
- for (int v2 = 0; v2 < simd_w; v2++) {
- if (is_fwd)
- F[j][i][v1][v2] = *(base_inp + v2);
- else
- F[j][i][v2][v1] = *(base_inp + v2);
- }
- }
- }
- }
-
- trans_W_4x4_3x3(Fw, F);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- for (int v1 = 0; v1 < simd_w; v1++) {
- PRAGMA_OMP_SIMD()
- for (int v2 = 0; v2 < simd_w; v2++) {
- output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2];
- }
- }
- }
- }
-}
-
-template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
-void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
- const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias,
- bool streamout = true) {
- float Ow[alpha][alpha][simd_w];
- float O[tile_size][tile_size][simd_w];
- int outw = is_fwd ? jcp.ow : jcp.iw;
- int outh = is_fwd ? jcp.oh : jcp.ih;
-
- /* Prepare for PostOps */
- bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
-
- array_offset_calculator<float, 8> input(toutp,
- jcp.dimN_nb_block, jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimN_block, jcp.dimM_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
-
- int tile_base_index = image * jcp.itiles * jcp.jtiles;
- int tile_block_ur = tile_base_index % jcp.tile_block_ur;
- int nb_tile_block_ur =
- (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
- int tile_block =
- (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
-
- for (int tj = 0; tj < jcp.jtiles; tj++) {
- for (int ti = 0; ti < jcp.itiles; ti++) {
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- Ow[j][i][v] = input(tile_block, 0,
- j, i,
- nb_tile_block_ur, 0,
- tile_block_ur, v);
- }
- }
- }
-
- trans_O_4x4_3x3(Ow, O);
-
- for (int j = 0; j < tile_size; j++) {
- int ydim = tj * tile_size + j;
- if (ydim < outh) {
- float *pout_j = pout_b + ydim * outw * simd_w;
- for (int i = 0; i < tile_size; i++) {
- int xdim = ti * tile_size + i;
- if (xdim < outw) {
- float *pout_i = pout_j + xdim * simd_w;
- if (is_fwd) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- O[j][i][v] += with_bias ? bias[v] : 0.f;
- O[j][i][v] = true
- && with_relu_presum && O[j][i][v] < 0.f
- ? O[j][i][v]
- * jcp.eltwise.alpha
- : O[j][i][v];
- }
- }
- if (with_sum)
- accum_output(pout_i, O[j][i], streamout,
- with_relu_postsum);
- else
- store_output(pout_i, O[j][i], streamout);
- }
- }
- }
- }
- tile_block_ur++;
- if (tile_block_ur >= jcp.tile_block_ur) {
- tile_block_ur = 0;
- nb_tile_block_ur++;
- }
- if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-}
-
-template <bool ver_4fma>
-void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
- float *inp, float *tinp, float *Iw_temp,
- void (*transpose_4fma_ker)(float *, float *))
-{
-
- const int ifwp = conv.iw + conv.l_pad;
- const int ifhp = conv.ih + conv.t_pad;
- float I[alpha][alpha][simd_w];
- float Iw[alpha][alpha][simd_w];
-
- array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp,
- alpha, alpha, conv.tile_4fma, simd_w);
- array_offset_calculator<float, 5> input(inp,
- conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w);
- array_offset_calculator<float, 8> output(tinp,
- conv.nb_ic, alpha, alpha,
- conv.tile_block, conv.ic_block,
- conv.nb_tile_block_ur, conv.tile_block_ur,
- conv.ic_simd_block * conv.tile_4fma);
-
- int tile_base_index =
- image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding);
- int tile_4fma = 0;
- int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur;
- int nb_tile_block_ur =
- (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
- % conv.nb_tile_block_ur;
- int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
- / conv.nb_tile_block_ur;
-
- for (int tj = 0; tj < conv.jtiles; tj++) {
- for (int ti = 0; ti < conv.itiles; ti++) {
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if ((conv.t_pad <= ydim) && ydim < ifhp) {
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if ((conv.l_pad <= xdim) && xdim < ifwp) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = input(0, 0,
- ydim - conv.t_pad,
- xdim - conv.l_pad, v);
- }
- } else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- } else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
- trans_I_4x4_3x3(Iw, I);
-
- if (ver_4fma) {
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- float *Iw_temp_base = &(Iw_trans_temp(j, i,
- tile_4fma, 0));
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- Iw_temp_base[v] = Iw[j][i][v];
- }
- }
- }
- tile_4fma++;
- if (tile_4fma == conv.tile_4fma) {
- float *outp = &(output(0, 0, 0,
- tile_block, 0,
- nb_tile_block_ur, tile_block_ur, 0));
- transpose_4fma_ker(outp, (float *)Iw_temp);
- tile_4fma = 0;
- tile_block_ur++;
- }
- } else {
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- store_output(&(output(0, j, i,
- tile_block, 0,
- nb_tile_block_ur, tile_block_ur, 0)),
- Iw[j][i], true);
- }
- }
- tile_block_ur++;
- }
-
- if (tile_block_ur == conv.tile_block_ur) {
- tile_block_ur = 0;
- ++nb_tile_block_ur;
- }
- if (nb_tile_block_ur == conv.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-
- if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) {
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) {
- float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0));
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- Iw_temp_base[v] = 0;
- }
- }
- }
- }
- float *outp = &(output(0, 0, 0,
- tile_block, 0,
- nb_tile_block_ur, tile_block_ur, 0));
- transpose_4fma_ker(outp, (float *)Iw_temp);
- }
-}
-
-template <bool with_bias>
-void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
- float *inp, float *tinp, float *dbias)
-{
-
- const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding;
- float I[alpha][alpha][simd_w];
- float Iw[alpha][alpha][simd_w];
-
- array_offset_calculator<float, 5> input(inp,
- conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block);
- array_offset_calculator<float, 8> output(tinp,
- conv.nb_oc, alpha, alpha,
- conv.tile_block, conv.oc_block,
- conv.nb_tile_block_ur,
- conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
-
- int tile_base_index = image * total_tiles;
- int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma);
- int nb_tile_block_ur =
- (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
- % conv.nb_tile_block_ur;
- int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
- / conv.nb_tile_block_ur;
-
- for (int tj = 0; tj < conv.jtiles; tj++) {
- for (int ti = 0; ti < conv.itiles; ti++) {
- for (int j = 0; j < alpha; j++) {
- int ydim = tj * tile_size + j;
- if (ydim < conv.oh) {
- for (int i = 0; i < alpha; i++) {
- int xdim = ti * tile_size + i;
- if (xdim < conv.ow) {
- float *input_base = &(input(0, 0, ydim, xdim, 0));
-
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = input_base[v];
- }
- if (with_bias && j < tile_size && i < tile_size) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- dbias[v] += input_base[v];
- }
- }
- } else {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- } else {
- for (int i = 0; i < alpha; i++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- I[j][i][v] = 0.0f;
- }
- }
- }
- }
-
- trans_W_3x3_4x4_wu(Iw, I);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- store_output(&(output(0, j, i,
- tile_block, 0,
- nb_tile_block_ur,
- tile_block_ur, 0)),
- Iw[j][i], true);
- }
- }
- tile_block_ur++;
- if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) {
- tile_block_ur = 0;
- nb_tile_block_ur++;
- }
- if (nb_tile_block_ur >= conv.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-}
-
-void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
- float *wp, float *twp)
-{
- const int kh = 3;
- const int kw = 3;
- float Fw[alpha][alpha][simd_w][simd_w];
- float F[kh][kw][simd_w][simd_w];
-
- array_offset_calculator<float, 8> input(twp,
- conv.nb_ic, conv.nb_oc,
- alpha, alpha,
- conv.oc_block, conv.ic_block,
- conv.ic_simd_block, conv.oc_simd_block);
- array_offset_calculator<float, 6> output(wp,
- conv.oc/simd_w, conv.ic/simd_w,
- conv.kh, conv.kw,
- conv.ic_simd_block, conv.oc_simd_block);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- for (int v = 0; v < conv.ic_simd_block; v++) {
- PRAGMA_OMP_SIMD()
- for (int k = 0; k < conv.oc_simd_block; k++) {
- Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k);
- }
- }
- }
- }
-
- trans_O_3x3_4x4_wu(Fw, F);
-
- for (int j = 0; j < kh; j++) {
- for (int i = 0; i < kw; i++) {
- for (int v = 0; v < conv.ic_simd_block; v++) {
- store_output(&(output(0, 0, j, i, v, 0)),
- F[j][i][v], true);
- }
- }
- }
-}
-
-template <bool is_fwd>
-void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
- float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const auto &p_ops = attr_->post_ops_;
-
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int outh = is_fwd ? jcp.oh : jcp.ih;
- const int outw = is_fwd ? jcp.ow : jcp.iw;
-
- /* Note that jcp.with_eltwise is true for both fused conv+relu primitive
- * and conv primitive with PostOps with relu before sum
- * (PostOps relu after sum is handled later) */
- auto output_transform = jcp.with_bias
- ? (jcp.with_eltwise
- ? (jcp.with_sum
- ? output_transform_data<is_fwd, true, true, true>
- : output_transform_data<is_fwd, true, true, false>)
- : (jcp.with_sum
- ? output_transform_data<is_fwd, true, false, true>
- : output_transform_data<is_fwd, true, false, false>))
- : (jcp.with_eltwise
- ? (jcp.with_sum
- ? output_transform_data<is_fwd, false, true, true>
- : output_transform_data<is_fwd, false, true, false>)
- : (jcp.with_sum
- ? output_transform_data<is_fwd, false, false, true>
- : output_transform_data<is_fwd, false, false, false>));
-
- /* Notation:
- FWD: dimM:oc, dimN:ntiles, dimK:ic,
- BWD: dimM:ic, dimN:ntiles, dimK:oc,
- FWD/BWD: V: src/diff_dst transform, U:weight transform,
- M:dst/diff_src transform */
- array_offset_calculator<float, 5> input(inp_ptr,
- jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
- jcp.dimK_reg_block);
- array_offset_calculator<float, 5> output(out_ptr,
- jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
- jcp.dimM_simd_block);
- array_offset_calculator<float, 6> weights(wei_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
- jcp.ic_simd_block, jcp.oc_simd_block);
- array_offset_calculator<float, 2> bias(bias_ptr,
- jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
-
- array_offset_calculator<float, 8> M(is_fwd
- ? scratchpad.template get<float>(key_wino_M)
- : scratchpad.template get<float>(key_wino_V),
- jcp.dimN_nb_block, jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimN_block, jcp.dimM_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> U(
- scratchpad.template get<float>(key_wino_U),
- jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimK_nb_block,
- jcp.dimM_block, jcp.dimK_block,
- jcp.dimK_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> V(is_fwd
- ? scratchpad.template get<float>(key_wino_V)
- : scratchpad.template get<float>(key_wino_M),
- jcp.dimN_nb_block, alpha, alpha,
- jcp.dimN_block, jcp.dimK_nb_block,
- jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
- > 2 * LLC_cache_size ? true : false;
-
- const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
-
- const bool wants_padded_bias = jcp.with_bias
- && jcp.oc_without_padding != jcp.oc;
- float last_slice_bias[simd_w] = {0};
- if (wants_padded_bias) {
- for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
- last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
- }
-
- {
- parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
- [&](int img, int K_blk1, int K_blk2) {
- input_transform_data<is_fwd>(img, jcp,
- &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
- &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout);
- });
-
- parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
- [&](int ofm1, int ifm1, int ofm2, int ifm2) {
- float *U_base_ptr = is_fwd
- ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
- : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
- weight_transform_data<is_fwd>(jcp,
- &(weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
- });
-
- parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
- [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
-
- kernel_->gemm_loop_ker_first_iter(
- (float *)&(M(N_blk1, M_blk1, oj, oi,
- N_blk2, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi,
- 0, 0, 0, 0, 0)),
- (const float *)&(V(N_blk1, oj, oi,
- N_blk2, 0, 0, 0, 0)));
- for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
- kernel_->gemm_loop_ker(
- (float *)&(M(N_blk1, M_blk1, oj, oi,
- N_blk2, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi,
- K_blk1, 0, 0, 0, 0)),
- (const float *)&(V(N_blk1, oj, oi,
- N_blk2, K_blk1,
- 0, 0, 0)));
- }
-
- });
-
- parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block,
- [&](int img, int M_blk1, int M_blk2) {
-
- const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
-
- float *bias_ptr = wants_padded_bias
- && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
- ? last_slice_bias : &bias(M_blk, 0);
-
- output_transform(img, jcp, p_ops,
- &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
- &(output(img, M_blk, 0, 0, 0)),
- bias_ptr, output_is_aligned);
-
- });
-
- }
-}
-
-template struct _jit_avx512_common_convolution_winograd_t<true>;
-template struct _jit_avx512_common_convolution_winograd_t<false>;
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_maybe_execute_diff_bias_copy(float *diff_bias,
- const memory_tracking::grantor_t &scratchpad) const {
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
- for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
- diff_bias[oc] = padded_bias[oc];
- }
-}
-
-void jit_avx512_common_convolution_winograd_bwd_weights_t::
-_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
- const memory_tracking::grantor_t &scratchpad) const {
- auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
- auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
- auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
-
- const auto &jcp = kernel_->jcp;
- const int nthreads = jcp.nthr;
-
- auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
- diff_src_transform_bwd_weights<true> :
- diff_src_transform_bwd_weights<false>;
- auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
- ? diff_dst_transform_bwd_weights<true>
- : diff_dst_transform_bwd_weights<false>;
-
- array_offset_calculator<float, 5> src((float *)ptr_src,
- jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
- jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
- jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
- ? scratchpad.get<float>(key_conv_padded_bias) : ptr_diff_bias,
- jcp.oc/simd_w, simd_w);
-
- array_offset_calculator<float, 8> U(
- scratchpad.get<float>(key_wino_U),
- jcp.nb_ic, jcp.nb_oc,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block, jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> M(
- scratchpad.get<float>(key_wino_M),
- jcp.nb_oc, alpha, alpha,
- jcp.tile_block, jcp.oc_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
- jcp.oc_simd_block);
- array_offset_calculator<float, 8> V(
- scratchpad.get<float>(key_wino_V),
- jcp.nb_ic, alpha, alpha,
- jcp.tile_block, jcp.ic_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur,
- jcp.ic_simd_block * jcp.tile_4fma);
-
- const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
- * jcp.ic_simd_block;
- array_offset_calculator<float, 2> trans_buffer(
- scratchpad.get<float>(key_conv_tr_src),
- nthreads,
- trans_buffer_size);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- scratchpad.get<float>(key_conv_bia_reduction),
- nthreads,
- jcp.oc);
-
-PRAGMA_OMP(parallel num_threads(nthreads))
- {
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
- diff_bias_prv(ithr, ofm) = 0.0f;
- });
-
-PRAGMA_OMP(for nowait)
- for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++)
- diff_bias(bofm, v) = 0.0f;
- }
- }
-
- const int ithread = mkldnn_get_thread_num();
-
- parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
- [&](int img, int ifm1, int ifm2) {
- float *transb = jcp.ver == ver_4fma
- ? &(trans_buffer(ithread, 0))
- : NULL;
- diff_src_transform_bwd_weights_ver(img, jcp,
- &(src(img, ifm1 * jcp.ic_block + ifm2,
- 0, 0, 0)),
- &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
- transb,
- kernel_->transpose_4fma_ker);
- });
-
- parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
- [&](int img, int ofm1, int ofm2) {
- float *dbias = jcp.with_bias
- ? &(diff_bias_prv(ithread,
- simd_w * (ofm1 * jcp.oc_block + ofm2)))
- : NULL;
- diff_dst_transform_bwd_weights_ver(img, jcp,
- &(diff_dst(img, ofm1 * jcp.oc_block + ofm2,
- 0, 0, 0)),
- &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)),
- dbias);
- });
-
-PRAGMA_OMP(barrier)
-
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
- parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
- [&](int oj, int oi, int ofm1) {
- kernel_->gemm_loop_ker_first_iter(
- (float *)&(U(ifm1, ofm1, oj, oi,
- 0, 0, 0, 0)),
- (const float *)&(M(ofm1, oj, oi,
- 0, 0, 0, 0, 0)),
- (const float *)&(V(ifm1, oj, oi,
- 0, 0, 0, 0, 0)));
- for (int tile_block = 1; tile_block < jcp.tile_block;
- tile_block++) {
- kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1,
- oj, oi,
- 0, 0, 0, 0)),
- (const float *)&(M(ofm1, oj, oi, tile_block,
- 0, 0, 0, 0)),
- (const float *)&(V(ifm1, oj, oi, tile_block,
- 0, 0, 0, 0)));
- }
- });
- }
-
-PRAGMA_OMP(barrier)
-
- parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
- [&](int ifm1, int ofm1, int ofm2, int ifm2) {
- diff_weights_transform_bwd_weights(jcp,
- &(diff_weights(ofm1 * jcp.oc_block + ofm2,
- ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
- &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
- });
-
- if (jcp.with_bias) {
-PRAGMA_OMP(for)
- for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
- for (int ithr = 0; ithr < nthreads; ithr++) {
- float* base_bias_ptr = &(diff_bias(ofm1, 0));
- float* base_bias_prv_ptr = &(diff_bias_prv(
- ithr * jcp.oc + ofm1 * simd_w));
- PRAGMA_OMP_SIMD()
- for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
- base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
- }
- }
- }
- }
- }
-
- _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad);
-}
-
-}
-}
-}
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
deleted file mode 100644
index 6c76f37c72..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
+++ /dev/null
@@ -1,318 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
-#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace winograd_avx512_common {
-inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_winograd_conf_t &jcp) {
- using namespace memory_tracking::names;
-
- size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
- size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
- size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc
- * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
-
- scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
- scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
- scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
-
- if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) {
- const int nthr = mkldnn_get_max_threads();
-
- size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr
- * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block;
- scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M);
-
- size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0;
- scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
-
- size_t padded_bias_sz =
- jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0;
- scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz);
- }
-}
-}
-
-template <bool is_fwd>
-struct _jit_avx512_common_convolution_winograd_t {
- _jit_avx512_common_convolution_winograd_t(
- const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
- : kernel_(nullptr), attr_(attr) {
- kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp);
- }
-
- ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; }
-
- protected:
- void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const;
- _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_;
- const primitive_attr_t *attr_;
-};
-
-struct jit_avx512_common_convolution_winograd_fwd_t
- : _jit_avx512_common_convolution_winograd_t<true>
- , public cpu_primitive_t
- {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
- jit_avx512_common_convolution_winograd_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32::
- init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(),
- *attr());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
- return set_default_formats_common(nChw16c, wei_tag, nChw16c);
- }
- };
-
- jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd)
- : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr())
- , cpu_primitive_t(apd, true) {}
-
- ~jit_avx512_common_convolution_winograd_fwd_t(){};
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override
- {
- auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
- this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
- (float *)bias, this->scratchpad(ctx));
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct jit_avx512_common_convolution_winograd_bwd_data_t
- : _jit_avx512_common_convolution_winograd_t<false>,
- public cpu_primitive_t {
- struct pd_t : public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
- jit_avx512_common_convolution_winograd_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && !has_zero_dim_memory()
- && set_default_formats()
- && mkldnn_thr_syncable();
- if (!ok) return status::unimplemented;
-
- status_t status =
- jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
- jcp_, *desc(), *diff_src_md(), *weights_md(),
- *diff_dst_md());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
- return set_default_formats_common(nChw16c, wei_tag, nChw16c);
- }
- };
-
- jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd)
- : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr())
- , cpu_primitive_t(apd, true) {}
-
- ~jit_avx512_common_convolution_winograd_bwd_data_t(){};
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC);
- this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
- (float *)weights, nullptr, this->scratchpad(ctx));
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct jit_avx512_common_convolution_winograd_bwd_weights_t
- : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr,
- hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
- jit_avx512_common_convolution_winograd_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats()
- && mkldnn_thr_syncable();
- if (!ok) return status::unimplemented;
-
- status_t status =
- jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::
- init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
- *diff_weights_md());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
- return set_default_formats_common(nChw16c, wei_tag, nChw16c);
- }
- };
-
- jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd, true), kernel_(nullptr)
- {
- kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
- pd()->jcp_);
- }
-
- ~jit_avx512_common_convolution_winograd_bwd_weights_t()
- { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override
- {
- _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx));
- return status::success;
- }
-
-private:
- void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
- const memory_tracking::grantor_t &scratchpad) const;
- void _maybe_execute_diff_bias_copy(float *diff_bias,
- const memory_tracking::grantor_t &scratchpad) const;
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_;
-};
-
-void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]);
-void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]);
-void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]);
-void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]);
-void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]);
-void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]);
-void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]);
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp
deleted file mode 100644
index d4a451c021..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp
+++ /dev/null
@@ -1,853 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_common_lrn.hpp"
-
-#include "jit_generator.hpp"
-
-#define FWD_RBC 4
-#define BWD_RBC 3
-
-#define XMM_SIZE (4*sizeof(float))
-#define ZMM_SIZE (vlen)
-#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE)
-#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE)
-#define SRC_PREV_OFFSET (vlen - XMM_SIZE)
-
-#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \
- statement;\
-}
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-enum params { vsize = 16, vlen = 64};
-
-typedef struct {
- const float *src;
- float *dst, *ws0, *ws1;
-} jit_args_fwd_t;
-
-typedef struct {
- const float *src, *diff_dst, *ws0, *ws1;
- float *diff_src;
-} jit_args_bwd_t;
-
-struct nChw16c_across {
-/* version:
- * -1: channels 0..15,
- * 1: channels C-16 .. C-1,
- * 0: other channels
- * 3: channels only for this kernel (without prev and next)
- */
- int H, W, version;
- nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {}
-};
-
-struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32:
- public jit_generator {
- int HW, W;
- bool is_first;
- bool is_last;
- bool is_single;
-
- Reg64 src = rax;
- Reg64 dst = r8;
- Reg64 scratch0 = rdx;
- Reg64 scratch1 = rsi;
- Reg64 imm_addr64 = rbx;
-
- Zmm zalpha = zmm0;
- Xmm xalpha = xmm0;
- Zmm zk = zmm1;
- Xmm xk = xmm1;
-
- Reg64 param = abi_param1;
- Reg64 t = rsp;
- Reg64 hw = r9;
-
- int xsrc_prev = 2;
- int zsrc = 7;
- int xsrc_next = 3;
- int zc = 7;
-
- int za = 2;
- int zb = 3;
- int zd = 5;
- int ze = 6;
- int zsum = 4;
- int zdst = 2;
- int zbase = 3;
- int zsum2 = 5;
-
- prop_kind_t pk;
- int use_h_parallelism;
-
- float alpha, k;
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
-
- void (*ker)(jit_args_fwd_t *);
- void operator()(jit_args_fwd_t *arg) { ker(arg); }
-
- enum {
- prf0_offt = 1*FWD_RBC,
- prf2_offt = 8*FWD_RBC
- };
-
- inline void compute_loop(int loop_size_param)
- {
- // loop_size - param for IRB_LOOP macro
- int loop_size = FWD_RBC;
-
- auto xreg = [=](int irb, int i) {
- return Xmm(irb*3 + i);
- };
-
- auto zreg = [=](int irb, int i) {
- return Zmm(irb*7 + i);
- };
-
- if (!is_first && !is_single) {
- IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen]));
- IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen]));
- }
- IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen)));
- IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen)));
- if (!is_last && !is_single) {
- IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen]));
- IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen]));
- }
- if (pk != prop_kind::forward_inference) {
- IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0,
- (irb + prf0_offt)*vlen)));
- IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0,
- (irb + prf2_offt)*vlen)));
- }
- IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen)));
- IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen)));
- if (pk != prop_kind::forward_inference) {
- IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1,
- (irb + prf0_offt) * vlen)));
- IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1,
- (irb + prf2_offt) * vlen)));
- }
-
- loop_size = loop_size_param;
- if (loop_size == 0)
- return;
- if (!is_first && !is_single) {
- IRB_LOOP(vmovups(xreg(irb, xsrc_prev),
- ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET]));
- }
- IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen)));
- if (!is_last && !is_single) {
- IRB_LOOP(vmovups(xreg(irb, xsrc_next),
- ptr[src + (irb + HW) * vlen]));
- }
-
- if (!is_first && !is_single) {
- IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
- xreg(irb, xsrc_prev)));
- }
- IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
- zreg(irb, zsrc)));
- if (!is_last && !is_single) {
- IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
- xreg(irb, xsrc_next)));
- }
-
- IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE - 2*sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE - sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE + sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE + 2*sizeof(float))));
-
- assert(zc == zsrc);
- IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc)));
-
- IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za)));
- IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb)));
- IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd)));
- IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze)));
-
- IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha));
-
- IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum)));
-
- IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum)));
- IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2)));
-
- IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
- IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
-
- if (pk != prop_kind::forward_inference) {
- IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen),
- zreg(irb, zsum)));
- }
- IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum)));
- IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst)));
- if (pk != prop_kind::forward_inference) {
- /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
- IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase)));
- IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen),
- zreg(irb, zsum)));
- }
- }
-
- jit_avx512_common_lrn_kernel_f32(
- const struct nChw16c_across &J,
- prop_kind_t prop_kind,
- int use_h_parallel,
- float A,
- float K,
- void *code_ptr = nullptr,
- size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE)
- : jit_generator(code_ptr, code_size)
- , pk(prop_kind)
- , use_h_parallelism(use_h_parallel)
- , alpha(A)
- , k(K)
- {
- this->preamble();
-
- mov(src, ptr[param + 0]);
- mov(dst, ptr[param + 8]);
- if (pk != prop_kind::forward_inference)
- {
- mov(scratch0, ptr[param + 16]);
- mov(scratch1, ptr[param + 24]);
- }
- is_first = J.version == -1 || J.version == -2;
- is_last = J.version == +1 || J.version == -2;
- is_single = J.version == 3;
-
- W = J.W;
- HW = J.W*J.H;
- int LSB = use_h_parallelism ? W : HW;
-
- sub(t, FWD_RBC*BUFFER_BLOCK);
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- vbroadcastss(zalpha, xalpha);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- vbroadcastss(zk, xk);
-
- if (is_first || is_single) {
- vxorps(xmm2, xmm2, xmm2);
- for(int irb = 0; irb < FWD_RBC; irb++) {
- vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2);
- }
- }
- if (is_last || is_single) {
- vxorps(xmm2, xmm2, xmm2);
- for(int irb = 0; irb < FWD_RBC; irb++) {
- vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
- xmm2);
- }
- }
-
- int LSREST = LSB % FWD_RBC;
- int LS = LSB - LSREST;
-
- Label lrn_loop;
-
- if (LS > 0) {
- mov(hw, LS);
-
- L(lrn_loop);
- {
- compute_loop(FWD_RBC);
-
- add(src, FWD_RBC*vlen);
- add(dst, FWD_RBC*vlen);
- if (pk != prop_kind::forward_inference)
- {
- add(scratch0, FWD_RBC*vlen);
- add(scratch1, FWD_RBC*vlen);
- }
-
- for(int irb = 0; irb < FWD_RBC; irb++)
- dec(hw);
- cmp(hw, 0);
- jne(lrn_loop, T_NEAR);
- }
- }
-
- compute_loop(LSREST);
-
- add(t, FWD_RBC*BUFFER_BLOCK);
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
- }
-};
-
-status_t jit_avx512_common_lrn_fwd_t::pd_t::init() {
- using namespace prop_kind;
- using namespace alg_kind;
-
- const memory_desc_wrapper data_d(src_md());
- bool ok = true
- && mayiuse(avx512_common)
- && is_fwd()
- && !has_zero_dim_memory()
- && everyone_is(data_type::f32, data_d.data_type())
- && data_d.ndims() == 4
- && data_d.dims()[1] % vsize == 0
- && attr()->has_default_values();
- if (!ok) return unimplemented;
-
- if (desc()->prop_kind == forward_training) {
- dims_t ws_dims = { MB(), C(), H(), 2*W() };
- mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32,
- format_tag::nChw16c);
- }
-
- bool args_ok_across = true
- && desc()->alg_kind == lrn_across_channels
- && desc()->local_size == 5
- && desc()->lrn_beta == 0.75
- && data_d.matches_tag(format_tag::nChw16c);
-
- return args_ok_across ? success : unimplemented;
-}
-
-jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
- , ker_last_(nullptr) {
- using namespace alg_kind;
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const int ls = pd()->desc()->local_size;
- const float alpha = pd()->desc()->lrn_alpha / ls;
- const float k = pd()->desc()->lrn_k;
-
- auto pk = pd()->desc()->prop_kind;
-
- use_h_parallelism = H > 28 ? 1 : 0;
-
- if (C / vsize == 1) {
- ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk,
- use_h_parallelism, alpha, k);
- } else {
- ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk,
- use_h_parallelism, alpha, k);
- ker_first_ = new jit_avx512_common_lrn_kernel_f32(
- nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k);
- ker_last_ = new jit_avx512_common_lrn_kernel_f32(
- nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k);
- }
-}
-
-jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t()
-{ delete ker_; delete ker_first_; delete ker_last_; }
-
-void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const
-{
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE);
-
- const int N = pd()->MB();
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- const int C16 = C / vsize;
- const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
-
- balance211(work_amount, nthr, ithr, start, end);
- if (use_h_parallelism) {
- int n{0}, c16{0}, h{0};
- nd_iterator_init(start, n, N, c16, C16, h, H);
- for (size_t iwork = start; iwork < end; ++iwork) {
- auto offset = n*C*H*W + c16*H*W*vsize
- + h*W*vsize;
- auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
- + h*2*W*vsize;
- auto ws_offset1 = ws_offset0 + W*vsize;
-
- jit_args_fwd_t args;
- args.src = &src[offset];
- args.dst = &dst[offset];
- args.ws0 = &ws[ws_offset0];
- args.ws1 = &ws[ws_offset1];
-
- if (C16 == 1)
- (*ker_)(&args);
- else if (c16 == 0)
- (*ker_first_)(&args);
- else if (c16 == C16 - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- nd_iterator_step(n, N, c16, C16, h, H);
- }
- } else {
- int n{0}, c16{0};
- nd_iterator_init(start, n, N, c16, C16);
- for (size_t iwork = start; iwork < end; ++iwork) {
- auto offset = n*C*H*W + c16*H*W*vsize;
- auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
- auto ws_offset1 = ws_offset0 + H*W*vsize;
-
- jit_args_fwd_t args;
- args.src = &src[offset];
- args.dst = &dst[offset];
- args.ws0 = &ws[ws_offset0];
- args.ws1 = &ws[ws_offset1];
-
- if (C16 == 1)
- (*ker_)(&args);
- else if (c16 == 0)
- (*ker_first_)(&args);
- else if (c16 == C16 - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
-
- nd_iterator_step(n, N, c16, C16);
- }
- }
- });
-}
-
-struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32:
- public jit_generator {
- int HW, W;
- bool is_first;
- bool is_last;
- bool is_single;
-
- Reg64 src = rax;
- Reg64 diffsrc = r8;
- Reg64 diffdst = r9;
- Reg64 workspace0 = rdx;
- Reg64 workspace1 = rsi;
- Reg64 imm_addr64 = rbx;
-
- Zmm znalphabeta = zmm0;
- Xmm xnalphabeta = xmm0;
-
- Reg64 param = abi_param1;
- Reg64 t = rsp;
- Reg64 hw = r10;
-
- int xws1_prev = 1;
- int xdiffdst_prev = 2;
- int zws1 = 1;
-
- int zsrc = 1;
- int zdiffdst = 5;
- int zdiffsrc = 6;
-
- int xws1_next = 1;
- int xdiffdst_next = 3;
-
- int za = 1;
- int zb = 2;
- int zd = 3;
- int ze = 4;
- int zws0 = 2;
-
- float nalphabeta;
-
- int use_h_parallelism;
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
-
- void (*ker)(jit_args_bwd_t *);
- void operator()(jit_args_bwd_t *arg) { ker(arg); }
-
- enum {
- prf0_offt = 1*BWD_RBC,
- prf2_offt = 8*BWD_RBC
- };
-
- inline void compute_loop(int loop_size_param, int prefetchL1,
- int prefetchL2)
- {
- // loop_size - param for IRB_LOOP macro
- int loop_size = loop_size_param;
-
- auto xreg = [=](int irb, int i) {
- return Xmm(irb*6 + i);
- };
-
- auto zreg = [=](int irb, int i) {
- return Zmm(irb*6 + i);
- };
-
-// ---- prefetching -------------------------------------------
- if (!is_first && !is_single) {
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
- - 2 * HW) * vlen]));
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
- - HW) * vlen]));
- }
-
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen]));
- if (prefetchL2)
- IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen]));
-
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen]));
-
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen]));
-
- if (!is_last && !is_single) {
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
- + 2 * HW) * vlen]));
- if (prefetchL2)
- IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt
- + 2 * HW) * vlen]));
-
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
- + HW) * vlen]));
- if (prefetchL2)
- IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt
- + HW) * vlen]));
- }
- if (prefetchL1)
- IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen]));
- if (prefetchL2)
- IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen]));
-// -----------------------------------------------------------
-
- if (loop_size_param == 0)
- return;
-
- if (!is_first && !is_single) {
- IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb
- - 2 * HW) * vlen + SRC_PREV_OFFSET]));
- IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb
- - HW) * vlen + SRC_PREV_OFFSET]));
- IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev),
- xreg(irb, xws1_prev)));
- }
-
- IRB_LOOP(vmovups(zreg(irb, zws1),
- EVEX_compress_addr(workspace1, irb*vlen)));
- IRB_LOOP(vmovups(zreg(irb, zdiffdst),
- EVEX_compress_addr(diffdst, irb*vlen)));
- IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst),
- zreg(irb, zws1)));
-
- if (!is_last && !is_single) {
- IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb
- + 2 * HW) * vlen]));
- IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb
- + HW) * vlen]));
- IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next),
- xreg(irb, xws1_next)));
- }
-
- if (!is_first && !is_single) {
- IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
- xreg(irb, xdiffdst_prev)));
- }
- IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
- zreg(irb, zdiffsrc)));
- if (!is_last && !is_single) {
- IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
- xreg(irb, xdiffdst_next)));
- }
-
- IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE - 2*sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE - 1*sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE + 1*sizeof(float))));
- IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
- + XMM_SIZE + 2*sizeof(float))));
- IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
- zreg(irb, za)));
- assert(zsrc == za);
- IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen)));
- IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
- zreg(irb, zb)));
- IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
- zreg(irb, zd)));
- IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
- zreg(irb, ze)));
- IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta));
-
- IRB_LOOP(vmovups(zreg(irb, zws0),
- EVEX_compress_addr(workspace0, irb*vlen)));
- IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst),
- zreg(irb, zws0)));
- IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc),
- zreg(irb, zdiffdst)));
-
- Label unaligned_store, end_store;
- test(diffsrc, vlen - 1);
- jnz(unaligned_store, T_NEAR);
- IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen),
- zreg(irb, zdiffsrc)));
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen),
- zreg(irb, zdiffsrc)));
- }
- L(end_store);
- }
-
- jit_avx512_common_lrn_kernel_f32(
- const struct nChw16c_across &J,
- float A,
- float B,
- int use_h_parallel,
- void *code_ptr = nullptr,
- size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE)
- : jit_generator(code_ptr, code_size)
- , nalphabeta(-2*A*B)
- , use_h_parallelism(use_h_parallel)
- {
- this->preamble();
-
- mov(src, ptr[param + 0]);
- mov(diffdst, ptr[param + 8]);
- mov(workspace0, ptr[param + 16]);
- mov(workspace1, ptr[param + 24]);
- mov(diffsrc, ptr[param + 32]);
-
- W = J.W;
- HW = J.H*J.W;
- int LSB = this->use_h_parallelism ? W : HW;
-
- sub(t, BWD_RBC*BUFFER_BLOCK);
- mov(imm_addr64, float2int(this->nalphabeta));
- movq(xnalphabeta, imm_addr64);
- vbroadcastss(znalphabeta, xnalphabeta);
-
- is_first = J.version == -1 || J.version == -2;
- is_last = J.version == +1 || J.version == +2;
- is_single = J.version == 3;
-
- if (is_first || is_single) {
- vxorps(xmm1, xmm1, xmm1);
- for(int irb = 0; irb < BWD_RBC; irb++) {
- vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1);
- }
- }
- if (is_last || is_single) {
- vxorps(xmm1, xmm1, xmm1);
- for(int irb = 0; irb < BWD_RBC; irb++) {
- vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1);
- }
- }
-
- int LSREST = LSB % BWD_RBC;
- int LS = LSB - LSREST;
-
- Label lrn_loop;
-
- if (LS > 0) {
- mov(hw, LS);
-
- L(lrn_loop);
- {
- compute_loop(BWD_RBC, 1, 1);
-
- add(src, BWD_RBC*vlen);
- add(diffsrc, BWD_RBC*vlen);
- add(diffdst, BWD_RBC*vlen);
- add(workspace0, BWD_RBC*vlen);
- add(workspace1, BWD_RBC*vlen);
-
- for(int irb = 0; irb < BWD_RBC; irb++)
- dec(hw);
- cmp(hw, 0);
- jne(lrn_loop, T_NEAR);
- }
- }
-
- compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1);
-
- add(t, BWD_RBC*BUFFER_BLOCK);
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
- }
-
-};
-
-status_t jit_avx512_common_lrn_bwd_t::pd_t::init() {
- using namespace alg_kind;
-
- const memory_desc_wrapper data_d(src_md());
- bool ok = true
- && mayiuse(avx512_common)
- && !is_fwd()
- && utils::everyone_is(data_type::f32, data_d.data_type())
- && !has_zero_dim_memory()
- && data_d.ndims() == 4
- && data_d.dims()[1] % vsize == 0
- && attr()->has_default_values();
- if (!ok) return unimplemented;
-
- dims_t ws_dims = { MB(), C(), H(), 2*W() };
- mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32,
- format_tag::nChw16c);
-
- if (!compare_ws(hint_fwd_pd_)) return unimplemented;
-
- bool args_ok_across = true
- && desc()->alg_kind == lrn_across_channels
- && desc()->local_size == 5
- && desc()->lrn_beta == 0.75
- && data_d.matches_tag(format_tag::nChw16c);
-
- return args_ok_across ? success : unimplemented;
-}
-
-jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
- , ker_last_(nullptr) {
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const int ls = pd()->desc()->local_size;
- const float alpha = pd()->desc()->lrn_alpha / ls;
- const float beta = pd()->desc()->lrn_beta;
-
- use_h_parallelism = H > 28 ? 1 : 0;
-
- if (C / vsize == 1) {
- ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3),
- alpha, beta, use_h_parallelism);
- } else {
- ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0),
- alpha, beta, use_h_parallelism);
- ker_first_ = new jit_avx512_common_lrn_kernel_f32(
- nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism);
- ker_last_ = new jit_avx512_common_lrn_kernel_f32(
- nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism);
- }
-}
-
-jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t()
-{ delete ker_; delete ker_first_; delete ker_last_; }
-
-void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const
-{
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const int N = pd()->MB();
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- const int C16 = C / vsize;
- const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
-
- balance211(work_amount, nthr, ithr, start, end);
- if (use_h_parallelism) {
- int n{0}, c16{0}, h{0};
- nd_iterator_init(start, n, N, h, H, c16, C16);
- for (size_t iwork = start; iwork < end; ++iwork) {
- auto offset = n*C*H*W + c16*H*W*vsize
- + h*W*vsize;
- auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
- + h*2*W*vsize;
- auto ws_offset1 = ws_offset0 + W*vsize;
-
- jit_args_bwd_t args;
- args.src = &src[offset];
- args.diff_dst = &diff_dst[offset];
- args.ws0 = &ws[ws_offset0];
- args.ws1 = &ws[ws_offset1];
- args.diff_src = &diff_src[offset];
-
- if (C16 == 1)
- (*ker_)(&args);
- else if (c16 == 0)
- (*ker_first_)(&args);
- else if (c16 == C16 - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- nd_iterator_step(n, N, h, H, c16, C16);
- }
- } else {
- int n{0}, c16{0};
- nd_iterator_init(start, n, N, c16, C16);
- for (size_t iwork = start; iwork < end; ++iwork) {
- auto offset = n*C*H*W + c16*H*W*vsize;
- auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
- auto ws_offset1 = ws_offset0 + H*W*vsize;
-
- jit_args_bwd_t args;
- args.src = &src[offset];
- args.diff_dst = &diff_dst[offset];
- args.ws0 = &ws[ws_offset0];
- args.ws1 = &ws[ws_offset1];
- args.diff_src = &diff_src[offset];
-
- if (C16 == 1)
- (*ker_)(&args);
- else if (c16 == 0)
- (*ker_first_)(&args);
- else if (c16 == C16 - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
-
- nd_iterator_step(n, N, c16, C16);
- }
- }
- });
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp
deleted file mode 100644
index 37fbb9b3e5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp
+++ /dev/null
@@ -1,96 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP
-#define CPU_JIT_AVX512_COMMON_LRN_HPP
-
-#include "c_types_map.hpp"
-
-#include "cpu_isa_traits.hpp"
-#include "cpu_lrn_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_fwd_pd_t {
- using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
- jit_avx512_common_lrn_fwd_t);
-
- status_t init();
- };
-
- jit_avx512_common_lrn_fwd_t(const pd_t *apd);
- ~jit_avx512_common_lrn_fwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- int use_h_parallelism;
- struct jit_avx512_common_lrn_kernel_f32;
- jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_;
-};
-
-struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_bwd_pd_t {
- using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
- jit_avx512_common_lrn_bwd_t);
-
- status_t init();
- };
-
- jit_avx512_common_lrn_bwd_t(const pd_t *apd);
- ~jit_avx512_common_lrn_bwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- int use_h_parallelism;
- struct jit_avx512_common_lrn_kernel_f32;
- jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp
deleted file mode 100644
index c58d3fa0a6..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp
+++ /dev/null
@@ -1,1103 +0,0 @@
-/*******************************************************************************
- * Copyright 2018 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_core_fp32_wino_conv_2x3.hpp"
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_kind;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
-struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- jit_avx512_core_fp32_wino_conv_2x3_src_trans_t)
-
- jit_conv_conf_2x3_wino_t jcp;
-
- struct call_params_t {
- const void *src;
- const void *wino_src;
- const void *v_y_masks;
- const void *v_x_masks;
- };
- void (*ker_)(const call_params_t *);
-
- jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp) {
- generate();
- ker_ =
- reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
- }
-
- void generate();
-
- Zmm vreg_inp(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return Zmm(31 - i);
- }
-
- Zmm vreg_tmp(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return Zmm(15 - i);
- }
-
- Zmm vreg_out(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return Zmm(31 - i);
- }
-
- Opmask y_mask = Opmask(1);
- Opmask r_mask = Opmask(2);
- Opmask x_mask(int id) {
- assert (id < 4);
- return Opmask(3 + id);
- }
-
- Reg64 reg_ptr_v_y_masks = r12;
- Reg64 reg_ptr_v_x_masks = r11;
-
- Reg64 reg_aux_ptr_src = r10;
- Reg64 reg_aux_ptr_dst = r9;
-
- Reg64 reg_ic_block = r8;
-
-};
-
-void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() {
- Label ic_block_label;
-
- const int load_block = 16;
- int out_offset = 0, inp_offset = 0;
- preamble();
-
-#define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_aux_ptr_src, src);
- READ_PARAM(reg_aux_ptr_dst, wino_src);
- READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
- READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
-#undef READ_PARAM
-
- for (int i = 0; i < jcp.alpha; i++) {
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
- }
- mov(reg_ic_block, jcp.ic / load_block);
- L(ic_block_label);
- {
- for (int y = 0; y < jcp.alpha; y++) {
- kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
- for (int x = 0; x < jcp.alpha; x++) {
- Zmm zmm = vreg_inp(y * jcp.alpha + x);
-
- vxorps(zmm, zmm, zmm);
- kandw(r_mask, y_mask, x_mask(x));
- inp_offset = sizeof(float)
- * ((-jcp.t_pad + y) * jcp.iw * load_block
- + (-jcp.l_pad + x) * load_block);
- vmovups(zmm | r_mask,
- EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
- }
- }
- for (int y = 0; y < jcp.alpha; y++) {
- vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0),
- vreg_inp(y * jcp.alpha + 2));
- vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1),
- vreg_inp(y * jcp.alpha + 2));
- vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2),
- vreg_inp(y * jcp.alpha + 1));
- vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1),
- vreg_inp(y * jcp.alpha + 3));
- }
- for (int x = 0; x < jcp.alpha; x++) {
- vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0),
- vreg_tmp(x + jcp.alpha * 2));
- vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
- vreg_tmp(x + jcp.alpha * 2));
- vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2),
- vreg_tmp(x + jcp.alpha * 1));
- vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
- vreg_tmp(x + jcp.alpha * 3));
- }
-
- for (int i = 0; i < 16; i++) {
- out_offset = sizeof(float) * (jcp.inp_stride * i);
- vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
- vreg_out(i));
- }
-
- add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block);
- add(reg_aux_ptr_dst, sizeof(float) * load_block);
- }
- dec(reg_ic_block);
- cmp(reg_ic_block, 0);
- jg(ic_block_label, T_NEAR);
- postamble();
-}
-
-/// DST TRANSFORMS /////////////////////////////////////////////////////////////
-struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t)
-
- jit_conv_conf_2x3_wino_t jcp;
- const primitive_attr_t &attr_;
-
- struct call_params_t {
- const void *wino_dst;
- const void *dst;
- const void *v_y_masks;
- const void *v_x_masks;
-
- const void *bias;
- const void *scales;
- };
- void (*ker_)(const call_params_t *);
-
- jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr) {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(
- const_cast<uint8_t *>(getCode()));
- }
-
- void generate();
- bool maybe_relu(int position);
-
- Zmm vreg_inp(int i) { // 16
- assert(i < jcp.alpha * jcp.alpha);
- return Zmm(31 - i);
- }
-
- Zmm vreg_stg(int id) { // 8
- const int id_reg_stg = jcp.alpha * jcp.alpha + id;
- assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
- return Zmm(31 - id_reg_stg);
- }
-
- Zmm vreg_out(int id) { // 4
- const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
- assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
- return Zmm(31 - id_reg_out);
- }
-
- Zmm vreg_tmp(int id) { // 2
- const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
- assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
- return Zmm(31 - id_reg_tmp);
- }
-
- Zmm vreg_zero = Zmm(0);
- Zmm vreg_prev_dst = Zmm(0);
- Zmm vreg_bias = Zmm(2);
-
- Opmask y_mask = Opmask(1);
- Opmask r_mask = Opmask(2);
- Opmask x_mask(int id) {
- assert (id < 4);
- return Opmask(3 + id);
- }
-
- Reg64 reg_ptr_v_y_masks = r12;
- Reg64 reg_ptr_v_x_masks = r11;
-
- Reg64 reg_aux_ptr_src = r10;
- Reg64 reg_aux_ptr_dst = r9;
-
- Reg64 reg_oc_block = r8;
-
- Reg64 reg_ptr_bias = rbx;
- Reg64 reg_ptr_scales = abi_not_param1;
- Reg64 reg_ptr_sum_scale = rdx;
-};
-
-bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* relu before sum */
- return false
- || p.contain(eltwise, 0);
- } else if (position == 1) {
- /* relu after sum */
- const int sum_idx = p.contain(sum, 0)
- ? 0 : (p.contain(sum, 1) ? 1 : -1);
- if (sum_idx == -1)
- return false;
-
- return false
- || p.contain(eltwise, sum_idx + 1);
- }
-
- return false;
-}
-
-void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() {
- Label oc_block_label;
-
- const int load_block = 16;
-
- auto loop_body = [=]() {
- const auto &p = attr_.post_ops_;
- const int sum_idx = p.find(primitive_kind::sum);
- const float *p_sum_scale = (sum_idx != -1)
- ? &p.entry_[sum_idx].sum.scale
- : nullptr;
- if (p_sum_scale && *p_sum_scale != 1.f)
- mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
-
- for (int i = 0; i < 16; i++) {
- int internal_offset = sizeof(float) * jcp.out_stride * i;
- vmovups(vreg_inp(i),
- EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
- }
- for (int y = 0; y < jcp.alpha; y++) {
- vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
- vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));
-
- vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
- vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3));
- }
- for (int x = 0; x < jcp.m; x++) {
- vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1));
- vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2));
-
- vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2));
- vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3));
- }
-
-
- if (jcp.with_bias) {
- auto bias_addr = ptr [ reg_ptr_bias ];
- vmovups(vreg_bias, bias_addr);
- }
- for (int y = 0; y < jcp.m; y++) {
- kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
- for (int x = 0; x < jcp.m; x++) {
- kandw(r_mask, y_mask, x_mask(x));
-
- int i = y * jcp.m + x;
- int offset = sizeof(float) *
- (y * jcp.ow * jcp.oc_block + x * jcp.oc_block);
- Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
-
- Zmm zmm = vreg_out(i);
- if (jcp.with_bias)
- vaddps(zmm, zmm, vreg_bias);
- vmulps(zmm, zmm, ptr [reg_ptr_scales]);
-
- if (maybe_relu(0)) {
- vxorps(vreg_zero, vreg_zero, vreg_zero);
- vmaxps(zmm, vreg_zero, zmm);
- }
- if (p_sum_scale) { // post_op: sum
- vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
- vmovups(vreg_prev_dst | r_mask, addr);
- if (*p_sum_scale == 1.f)
- vaddps(zmm, vreg_prev_dst);
- else
- vfmadd231ps(zmm, vreg_prev_dst,
- zword_b[reg_ptr_sum_scale]);
- }
- if (maybe_relu(1)) {
- vxorps(vreg_zero, vreg_zero, vreg_zero);
- vmaxps(zmm, vreg_zero, zmm);
- }
-
- vmovups(addr, zmm | r_mask);
- }
- }
- };
-
- preamble();
-
-#define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_aux_ptr_src, wino_dst);
- READ_PARAM(reg_aux_ptr_dst, dst);
- READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
- READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
- READ_PARAM(reg_ptr_bias, bias);
- READ_PARAM(reg_ptr_scales, scales);
-#undef READ_PARAM
-
- for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
- vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i));
-
- for (int i = 0; i < jcp.alpha; i++)
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
-
- int oc_blocks = 1;
- oc_blocks = jcp.oc / load_block;
- mov(reg_oc_block, oc_blocks);
- L(oc_block_label);
- {
- loop_body();
- add(reg_aux_ptr_src, sizeof(float) * load_block);
- add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block);
-
- add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
- add(reg_ptr_bias, jcp.typesize_bia * load_block);
- }
- dec(reg_oc_block);
- cmp(reg_oc_block, 0);
- jg(oc_block_label, T_NEAR);
-
- sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
- sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block);
-
- postamble();
-
-}
-
-/// GEMM kernel ////////////////////////////////////////////////////////////////
-struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t)
- jit_conv_conf_2x3_wino_t jcp;
-
- struct call_params_t {
- const void *src;
- const void *dst;
- const void *wei;
- const void *dst_b;
- };
- void (*ker_)(const call_params_t *);
-
- void generate();
- static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
- const primitive_attr_t &attr);
-
- jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp) {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(
- const_cast<uint8_t *>(getCode()));
- }
-
- static status_t init_conf(
- jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &weights_md,
- memory_desc_t &dst_md, memory_desc_t &bias_md,
- const primitive_attr_t &attr,
- memory_desc_t& expect_wei_md);
-
- Zmm vreg_out(int n, int m) {
- const int id_reg_out = n * jcp.m_block + m;
- assert(id_reg_out < jcp.n2_block * jcp.m_block);
- return Zmm(31 - id_reg_out);
- }
- Zmm vreg_wei(int i) {
- assert (31 - jcp.n2_block * jcp.m_block - i > 1);
- return Zmm(31 - jcp.n2_block * jcp.m_block - i);
- }
-
- Zmm vreg_src = Zmm(0);
- Zmm vreg_one = Zmm(1);
- Zmm vreg_tmp = Zmm(2);
-
- Reg64 reg_ptr_src = r15;
-
- Reg64 reg_aux_dst = r12;
- Reg64 reg_aux_dst2 = r11;
- Reg64 reg_aux_wei = r10;
- Reg64 reg_aux_wei2 = r9;
- Reg64 reg_aux_src = r8;
- Reg64 reg_aux_src2 = rax;
-
- Reg64 reg_mb = rbx;
- Reg64 reg_nnb = rdx;
- Reg64 reg_K = rsi;
-
-};
-
-bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
- jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
- using namespace primitive_kind;
- const auto &p = attr.post_ops_;
-
- auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
-
- switch (p.len_) {
- case 0: return true;
- case 1: return is_relu(0) || p.contain(sum, 0);
- case 2: return (p.contain(sum, 0) && is_relu(1)) ||
- (p.contain(sum, 1) && is_relu(0));
- case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
- default: return false;
- }
-
- return false;
-}
-
-void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() {
- Label nnb_loop_label, K_loop_label, mb_loop_label;
-
- preamble();
-#define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_ptr_src, src);
- READ_PARAM(reg_aux_dst, dst);
- READ_PARAM(reg_aux_wei, wei);
-#undef READ_PARAM
-
- if (!jcp.small_mb) {
- mov(reg_nnb, jcp.n_chunks);
- L(nnb_loop_label);
- }
- mov(reg_aux_dst2, reg_aux_dst);
- mov(reg_aux_src, reg_ptr_src);
- mov(reg_mb, jcp.M / jcp.m_block);
- L(mb_loop_label);
- {
- int nb2 = 0;
- for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- for (int m = 0; m < jcp.m_block; m++) {
- vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m));
- }
- }
- mov(reg_aux_src2, reg_aux_src);
- mov(reg_aux_wei2, reg_aux_wei);
-
- mov(reg_K, jcp.k_chunks);
- L(K_loop_label); {
- int wei_offset = 0;
- for (int _i = 0; _i < jcp.k2_block; _i++) {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- if (jcp.small_mb) {
- int wei_offset = sizeof(float)
- * ((nb2 * jcp.nb_ic * jcp.ic_block
- * jcp.oc_block)
- + _i * jcp.oc_block);
- vmovups(vreg_wei(nb2),
- EVEX_compress_addr(reg_aux_wei2, wei_offset));
- } else {
- vmovups(vreg_wei(nb2),
- EVEX_compress_addr(reg_aux_wei2,
- sizeof(float) * wei_offset));
- wei_offset += jcp.oc_block;
- }
- }
- for (int m = 0; m < jcp.m_block; m++) {
- int inp_offset = sizeof(float) * (m * jcp.K + _i);
- if (jcp.n2_block > 1) {
- vbroadcastss(vreg_src,
- EVEX_compress_addr(reg_aux_src2, inp_offset));
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
- vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2),
- vreg_src);
- } else {
- vfmadd231ps(vreg_out(0, m), vreg_wei(0),
- EVEX_compress_addr(reg_aux_src2, inp_offset, true));
- }
- }
- }
- add(reg_aux_src2, sizeof(float) * jcp.ic_block);
- if (jcp.small_mb)
- add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block);
- else
- add(reg_aux_wei2,
- sizeof(float) * jcp.k2_block * jcp.n2_block
- * jcp.oc_block);
- }
- dec(reg_K);
- cmp(reg_K, 0);
- jg(K_loop_label, T_NEAR);
-
- for (int m = 0; m < jcp.m_block; m++) {
- int nb2 = 0;
- for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- int offset = sizeof(float) *
- (m * jcp.N + nb2 * jcp.oc_block);
- vmovups(EVEX_compress_addr(reg_aux_dst2,offset),
- vreg_out(nb2, m));
- }
- }
- add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K);
- add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
- }
- dec(reg_mb);
- cmp(reg_mb, 0);
- jg(mb_loop_label, T_NEAR);
-
- if (!jcp.small_mb) {
- add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block);
- add(reg_aux_wei,
- sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block
- * jcp.oc_block);
-
- dec(reg_nnb);
- cmp(reg_nnb, 0);
- jg(nnb_loop_label, T_NEAR);
- }
- postamble();
-}
-
-namespace {
-bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
- return jcp.mb >= 4;
-}
-}
-
-status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
- jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &wei_md,
- memory_desc_t &dst_md, memory_desc_t &bias_md,
- const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper wei_d(&wei_md);
- const memory_desc_wrapper dst_d(&dst_md);
- const memory_desc_wrapper bias_d(&bias_md);
-
- const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
-
- jcp.nthr = mkldnn_get_max_threads();
-
- jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
- jcp.kh = wei_d.dims()[with_groups + 2];
- jcp.kw = wei_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.b_pad = cd.padding[1][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.r_pad = cd.padding[1][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- jcp.m = 2;
- jcp.r = 3;
- jcp.alpha = jcp.m + jcp.r - 1;
- int simdw = 16;
-
- format_tag_t dat_tag = format_tag::nChw16c;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- if (jcp.src_tag != dat_tag) return status::unimplemented;
- if (jcp.dst_tag != dat_tag) return status::unimplemented;
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- bool ok_to_pad_channels = jcp.ngroups == 1;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simdw);
- jcp.ic = rnd_up(jcp.ic, simdw);
- }
-
- jcp.ver = ver_avx512_core;
- if (!(mayiuse(avx512_core)))
- return status::unimplemented;
-
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
-
- if (src_d.data_type() != data_type::f32)
- return status::unimplemented;
- if (wei_d.data_type() != data_type::f32)
- return status::unimplemented;
- if (dst_d.data_type() != data_type::f32)
- return status::unimplemented;
-
- jcp.ic_block = simdw;
- jcp.oc_block = simdw;
-
- bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
- && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
- && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0
- && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad
- && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0
- && jcp.l_pad < 2 && jcp.l_pad >= 0;
- if (!ok)
- return status::unimplemented;
-
- const int L2_cap = get_cache_size(2, true) / sizeof(float);
- const int L3_capacity = get_cache_size(3, false) / sizeof(float);
- int a = jcp.alpha;
- int aa = a * a;
- int mb = jcp.mb;
- int ic = jcp.ic;
- int oc = jcp.oc;
- int ih = jcp.ih;
- int iw = jcp.iw;
- auto wei_sz = (float)aa * ic * oc;
- auto inp_sz = (float)mb * ih * iw * ic;
- auto sp_sz = (float)mb * ih * iw;
-
- /* Heuristics here. The numbers '28' and '196' are observations from data. */
- if (wei_sz / inp_sz > 5)
- jcp.small_mb = true;
- else
- jcp.small_mb = false;
-
- if (mb > nstl::min(jcp.nthr, 28)
- || (!jcp.small_mb
- && (wei_sz >= 0.9f * L2_cap
- || inp_sz > L2_cap * jcp.nthr + L3_capacity))
- || (jcp.small_mb && sp_sz > 196))
- return status::unimplemented;
-
- jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
- jcp.dst_dt = cd.dst_desc.data_type;
-
- jcp.typesize_bia
- = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
-
- jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- const int skx_free_regs = 30;
-
- auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block,
- int &n2_block, float &reg_eff) {
- M = (xb * yb) / jcp.alpha;
- int max_m_block = m_block = nstl::min(M, skx_free_regs);
- int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs);
- reg_eff = 0;
- for (int im = max_m_block; im > 0; im--) {
- for (int in2 = max_n2_block; in2 > 0; in2--) {
- int used_regs = in2 * im + in2;
- float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f;
- if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs
- || cur_reg_eff <= reg_eff)
- continue;
- reg_eff = cur_reg_eff;
- m_block = im;
- n2_block = in2;
- }
- }
- };
-
- int oh = jcp.oh;
- int ow = jcp.ow;
- int nb_oc = jcp.nb_oc;
- int Z = ic + oc;
- int Y = ic * oc;
- const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float);
-
- /* Selecting xb and yb blocking */
- int min_yb = jcp.alpha;
- int min_xb = jcp.alpha;
- int max_yb = nstl::max(min_yb, rnd_up(ih, 2));
- int max_xb = nstl::max(min_xb, rnd_up(iw, 2));
- float best_eff = 0.f;
- for (int ix = max_xb; ix >= min_xb; ix -= 2) {
- if (rnd_up(ow, ix) < iw - 2)
- continue;
- for (int iy = max_yb; iy >= min_yb; iy -= 2) {
- if (rnd_up(oh, iy) < ih - 2)
- continue;
- int ex_y = rnd_up(oh, iy);
- int ex_x = rnd_up(ow, ix);
- float work_eff = (float)(ih * iw) / (ex_y * ex_x);
-
- int M, m_block, n2_b;
- float reg_eff, thr_eff, par_eff, mem_eff, req_mem;
-
- find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff);
-
- /* outer parallelization */
- int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
- thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
-
- mem_eff = 1.f;
- req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
- if (req_mem > L2_cap / 2) {
- if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7)
- mem_eff /= (n2_b + 1) / 2.f;
- else
- mem_eff /= (n2_b + 1) / 3.f;
- }
-
- float outer_eff = thr_eff + work_eff + reg_eff + mem_eff;
-
- /* inner parallelization */
- int bsz = iy * ix / a;
- int gemmw = aa * (nb_oc / n2_b);
- int bsz_r = rnd_up(bsz, jcp.nthr);
- int gemmw_r = rnd_up(gemmw, jcp.nthr);
- thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);
-
- req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
- mem_eff = nstl::min(1.f, L2_cap / req_mem);
- int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
- int oc_per_thr =
- nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
- req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
- if (req_mem > L2_cap)
- mem_eff = 0.1f;
- par_eff = 1 / (2.f * nblocks);
-
- float inner_eff = thr_eff + work_eff + mem_eff + par_eff;
-
- float eff = jcp.small_mb ? inner_eff : outer_eff;
- if (eff > best_eff) {
- best_eff = eff;
- jcp.yb = iy;
- jcp.xb = ix;
- jcp.M = M;
- jcp.m_block = m_block;
- jcp.n2_block = n2_b;
- }
- }
- }
-
- assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
-
- jcp.inp_stride = jcp.M * jcp.ic;
- jcp.out_stride = jcp.M * jcp.oc;
- jcp.wei_stride = jcp.ic * jcp.oc;
- jcp.bia_stride = jcp.oc;
-
- jcp.N = jcp.oc;
- jcp.K = jcp.ic;
-
- jcp.n_block = jcp.oc_block;
- jcp.k_block = jcp.ic_block;
-
- assert(jcp.M % jcp.m_block == 0);
- assert(jcp.nb_oc % jcp.n2_block == 0);
-
- jcp.n_chunks = jcp.nb_oc / jcp.n2_block;
- jcp.k2_block = jcp.ic_block;
- jcp.k_chunks = jcp.K / jcp.k2_block;
-
- const auto &oscales = attr.output_scales_;
- jcp.is_oc_scale = oscales.mask_ == 1 << 1;
- assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
-
- /* re-create weights primitive descriptor
- and set weights wino_blocking */
- expect_wei_md.format_kind = format_kind::wino;
- expect_wei_md.data_type = data_type::f32;
- mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
- wd.wino_format
- = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo;
- wd.r = jcp.r;
- wd.alpha = jcp.alpha;
- wd.ic = jcp.ic;
- wd.oc = jcp.oc;
- wd.ic_block = jcp.ic_block;
- wd.oc_block = jcp.oc_block;
- wd.oc2_block = jcp.n2_block;
- wd.ic2_block = 1;
- wd.adj_scale = 1.f;
- size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
- wd.size = max_size;
-
- return status::success;
-}
-////////////////////////////////////////////////////////////////////////////////
-
-status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
- ::pd_t::jit_conf(memory_desc_t& expect_wei_md) {
- return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
- jcp_, *this->desc(), this->src_md_, this->weights_md_,
- this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md);
-}
-
-jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
- jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
-{
- kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
- pd()->jcp_, *pd()->attr());
- src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
- pd()->jcp_, *pd()->attr());
- dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
- pd()->jcp_, *pd()->attr());
-}
-
-jit_avx512_core_fp32_wino_conv_2x3_fwd_t
- ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
- delete kernel_;
- delete src_trans_;
- delete dst_trans_;
-}
-
-void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN(
- const float *src, const float *wei, const float *bia, float *dst,
- const memory_tracking::grantor_t &scratchpad) const
-{
- const auto &jcp = kernel_->jcp;
- const auto &oscales = pd()->attr()->output_scales_;
-
- const size_t wino_size_offset =
- (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
- const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
- const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
-
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
- utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bia = padded_bias;
- }
-
- auto ptr_V = scratchpad.get<float>(key_wino_V);
- auto ptr_M = scratchpad.get<float>(key_wino_M);
-
- parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb),
- [&](int mb, int tile_y_b, int tile_x_b) {
- int tile_y = tile_y_b * jcp.yb;
- int tile_x = tile_x_b * jcp.xb;
-
- int ithr = mkldnn_get_thread_num();
- auto wino_src = ptr_V + size_wino_src * ithr;
- auto wino_dst = ptr_M + size_wino_dst * ithr;
-
- auto src_trans_p =
- jit_avx512_core_fp32_wino_conv_2x3_src_trans_t
- ::call_params_t();
- auto dst_trans_p =
- jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t
- ::call_params_t();
- auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
- call_params_t();
-
- /* transformation of input tensor to winograd domain */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
- for (int x_in_block = 0; x_in_block < jcp.xb;
- x_in_block += 2) {
-
- unsigned short v_y_masks[4], v_x_masks[4];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2)
- + (x_in_block / 2);
-
- int v_ys = nstl::max(0, jcp.t_pad - y);
- int v_ye = nstl::min(jcp.alpha,
- nstl::max(0, jcp.ih + jcp.t_pad - y));
-
- int v_xs = nstl::max(0, jcp.l_pad - x);
- int v_xe = nstl::min(jcp.alpha,
- nstl::max(0, jcp.iw + jcp.l_pad - x));
-
-#pragma unroll(4)
- for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
- v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
- }
- auto local_s = src
- + mb * jcp.nb_ic * jcp.ih * jcp.iw
- * jcp.ic_block
- + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
- auto local_w = wino_src + m * jcp.ic;
-
- src_trans_p.src = local_s;
- src_trans_p.wino_src = local_w;
- src_trans_p.v_y_masks = v_y_masks;
- src_trans_p.v_x_masks = v_x_masks;
-
- src_trans_->ker_(&src_trans_p);
- }
- }
- /* gemms */
- for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
- int offset = (tile_ij + ithr) % 16;
- gemm_p.src = wino_src + jcp.inp_stride * offset;
- gemm_p.dst = wino_dst + jcp.out_stride * offset;
- gemm_p.wei = wei + jcp.wei_stride * offset;
-
- kernel_->ker_(&gemm_p);
- }
-
- /* transformation from winograd domain to output tensor */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
- for (int x_in_block = 0; x_in_block < jcp.xb;
- x_in_block += 2) {
- unsigned short v_y_masks[2], v_x_masks[2];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2)
- + (x_in_block / 2);
-
-#pragma unroll(2)
- for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
- v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
- }
- auto local_d = dst
- + mb * jcp.nb_oc * jcp.oh * jcp.ow
- * jcp.oc_block
- + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
- auto local_w = wino_dst + m * jcp.oc;
-
- auto scales = oscales.scales_;
- dst_trans_p.dst = local_d;
- dst_trans_p.wino_dst = local_w;
- dst_trans_p.v_y_masks = v_y_masks;
- dst_trans_p.v_x_masks = v_x_masks;
-
- dst_trans_p.scales = scales;
- dst_trans_p.bias = bia;
-
- dst_trans_->ker_(&dst_trans_p);
- }
- }
- });
-}
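[Note] The gemm loop above offsets each thread's sweep over the 16 Winograd-domain matrices by `(tile_ij + ithr) % 16`, presumably so concurrently running threads do not all read the same weight slice at the same moment. A small standalone sketch of that rotation (illustrative only):

    #include <cstdio>

    // Rotated visiting order over the alpha*alpha = 16 transform points,
    // mirroring the offset computation in execute_forward_mbN above.
    void visit_staggered(int ithr, int npoints) {
        for (int tile_ij = 0; tile_ij < npoints; tile_ij++) {
            int offset = (tile_ij + ithr) % npoints; // thread-dependent start
            std::printf("thread %d, step %d -> point %d\n",
                    ithr, tile_ij, offset);
        }
    }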
-
-void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb(
- const float *src, const float *wei, const float *bia, float *dst,
- const memory_tracking::grantor_t &scratchpad) const
-{
- const auto &jcp = kernel_->jcp;
- const auto &oscales = pd()->attr()->output_scales_;
-
- if (pd()->wants_padded_bias()) {
- auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
- utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bia = padded_bias;
- }
-
- auto ptr_V = scratchpad.get<float>(key_wino_V);
- auto ptr_M = scratchpad.get<float>(key_wino_M);
-
- for (int mb = 0; mb < jcp.mb; mb++) {
- for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
- for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
- /* transformation of input tensor to winograd domain */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
-
- auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t ::
- call_params_t();
-
- unsigned short v_y_masks[4], v_x_masks[4];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
-
- int v_ys = nstl::max(0, jcp.t_pad - y);
- int v_ye = nstl::min(
- jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
-
- int v_xs = nstl::max(0, jcp.l_pad - x);
- int v_xe = nstl::min(
- jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
-
-#pragma unroll(4)
- for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
- v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
- }
- auto local_s = src
- + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
- + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
- auto local_w = ptr_V + m * jcp.ic;
-
- src_trans_p.src = local_s;
- src_trans_p.wino_src = local_w;
- src_trans_p.v_y_masks = v_y_masks;
- src_trans_p.v_x_masks = v_x_masks;
-
- src_trans_->ker_(&src_trans_p);
- });
-
- /* gemms */
- parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
- auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
- call_params_t();
-
- gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
- gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
- + nnb * jcp.n2_block * jcp.n_block;
- gemm_p.wei = wei + jcp.wei_stride * tile_ij
- + nnb * jcp.n2_block * jcp.n_block * jcp.K;
-
- kernel_->ker_(&gemm_p);
- });
-
- /* transformation from winograd domain to output tensor */
-
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
-
- auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t ::
- call_params_t();
-
- unsigned short v_y_masks[2], v_x_masks[2];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
-
-#pragma unroll(2)
- for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
- v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
- }
- auto local_d = dst
- + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
- + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
- auto local_w = ptr_M + m * jcp.oc;
-
- auto scales = oscales.scales_;
- dst_trans_p.dst = local_d;
- dst_trans_p.wino_dst = local_w;
- dst_trans_p.v_y_masks = v_y_masks;
- dst_trans_p.v_x_masks = v_x_masks;
-
- dst_trans_p.scales = scales;
- dst_trans_p.bias = bia;
-
- dst_trans_->ker_(&dst_trans_p);
- });
- }}}
-}
-
-} // namespace cpu
-} // namespace impl
-} // namespace mkldnn
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp
deleted file mode 100644
index 7e38b07f5a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp
+++ /dev/null
@@ -1,144 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP
-#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_primitive_conf.hpp"
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t;
-struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t;
-struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t;
-
-struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""),
- jit_avx512_core_fp32_wino_conv_2x3_fwd_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::forward_inference
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- memory_desc_t expect_wei_md = *weights_md();
- status_t jit_conf_result = jit_conf(expect_wei_md);
- if (jit_conf_result != status::success) return jit_conf_result;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- if (weights_md_.format_kind == format_kind::any)
- weights_md_ = expect_wei_md;
- if (weights_md_ != expect_wei_md)
- return status::unimplemented;
-
- init_scratchpad();
-
- return status::success;
- }
-
- jit_conv_conf_2x3_wino_t jcp_;
-
- protected:
- status_t jit_conf(memory_desc_t& expect_wei_md);
-
- void init_scratchpad() {
- using namespace memory_tracking::names;
-
- auto scratchpad = scratchpad_registry().registrar();
-
- int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb;
-
- size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr;
- scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K);
-
- size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr;
- scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K);
-
- if (wants_padded_bias()) {
- assert(jcp_.ngroups == 1);
- scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc);
- }
- }
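[Note] The bookings above size the per-thread Winograd scratch: each thread owns (yb/2)*(xb/2) + xb tile slots, each covering the 16 transform points (alpha * alpha for this 2x3 kernel), for ic and oc channels respectively. A hedged worked example with hypothetical numbers:

    #include <cstddef>

    // Hypothetical shapes: yb = xb = 8 and ic = 64 give
    // (8/2)*(8/2) + 8 = 24 tile slots, so 64 * 16 * 24 = 24576 floats
    // (96 KiB) of V scratch per thread; M is sized the same way with oc.
    constexpr std::size_t wino_scratch_floats(int channels, int yb, int xb,
            int nthr) {
        return (std::size_t)channels * 16
                * ((std::size_t)(yb / 2) * (xb / 2) + xb) * nthr;
    }
    static_assert(wino_scratch_floats(64, 8, 8, 1) == 24576,
            "96 KiB of floats per thread in this example");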
-
- bool set_default_formats() {
- using namespace format_tag;
- return set_default_formats_common(nChw16c, any, nChw16c);
- }
- };
-
- jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd);
- ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t();
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
- auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
- auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
-
- if (pd()->jcp_.small_mb)
- execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx));
- else
- execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx));
-
- return status::success;
- }
-
-private:
- void execute_forward_small_mb(const float *src, const float *wei,
- const float *bia, float *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- void execute_forward_mbN(const float *src, const float *wei,
- const float *bia, float *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_;
- jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_;
- jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp
deleted file mode 100644
index 96325e3ade..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp
+++ /dev/null
@@ -1,1020 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifdef __INTEL_COMPILER
-#include <immintrin.h>
-#endif
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
-
-#ifndef _MSC_VER
-#define pragma_unroll _Pragma("unroll")
-#else
-#define pragma_unroll
-#endif
-
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-template <bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
-::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
- float *wp, float *twp) const
-{
- float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
- 1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
- const int kh = 3;
- const int kw = 3;
- float Fw[alpha][alpha][simd_w][simd_w];
- float F[kh][kw][simd_w][simd_w];
- float T[alpha][3][simd_w];
- auto p = jit_wino_transform_call_s();
-
- p.src = wp;
- p.dst = twp;
- p.G = G;
- p.M = F;
- p.Mw = Fw;
- p.T = T;
-
- kernel_->weights_transform_data_ker(&p);
-}
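[Note] These transform kernels implement the usual Winograd relation Y = A^T ((G g G^T) .* (B^T d B)) A from Lavin & Gray; the six constants above appear to parameterize the 6x3 weight-transform matrix G of F(4x4, 3x3). For reference, a naive single-channel weight transform U = G g G^T, in place of the SIMD-blocked JIT version (sketch only):

    #include <array>

    // Naive reference: U = G * g * G^T for one 3x3 filter channel.
    // ALPHA x 3 transform matrix G, 3x3 filter g, ALPHA x ALPHA result U.
    template <int ALPHA>
    std::array<std::array<float, ALPHA>, ALPHA> winograd_weight_transform(
            const float (&G)[ALPHA][3], const float (&g)[3][3]) {
        float Gg[ALPHA][3] = {};
        for (int i = 0; i < ALPHA; ++i)
            for (int k = 0; k < 3; ++k)
                for (int j = 0; j < 3; ++j)
                    Gg[i][k] += G[i][j] * g[j][k];

        std::array<std::array<float, ALPHA>, ALPHA> U = {};
        for (int i = 0; i < ALPHA; ++i)
            for (int j = 0; j < ALPHA; ++j)
                for (int k = 0; k < 3; ++k)
                    U[i][j] += Gg[i][k] * G[j][k]; // right-multiply by G^T
        return U;
    }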
-
-template<bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
-(int image, const jit_conv_winograd_conf_t &jcp,
- const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
-
- float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
- float Ow[alpha][alpha][simd_w];
- float O[tile_size][tile_size][simd_w];
- float T[tile_size][alpha][simd_w];
-
- auto p = jit_wino_transform_call_s();
- p.src = toutp;
- p.dst = pout_b;
- p.G = G;
- p.M = O;
- p.Mw = Ow;
- p.T = T;
- p.bias = bias;
-
- int tile_base_index = image * jcp.itiles * jcp.jtiles;
- int tile_block_ur = tile_base_index % jcp.tile_block_ur;
- int nb_tile_block_ur =
- (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
- int tile_block =
- (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
-
- for (int tj = 0; tj < jcp.jtiles; tj++) {
- for (int ti = 0; ti < jcp.itiles; ti++) {
-
- p.tile_block_ur = tile_block_ur;
- p.nb_tile_block_ur = nb_tile_block_ur;
- p.tile_block = tile_block;
- p.tj = tj;
- p.ti = ti;
-
- kernel_->output_transform_data_ker(&p);
-
- tile_block_ur++;
- if (tile_block_ur >= jcp.tile_block_ur) {
- tile_block_ur = 0;
- nb_tile_block_ur++;
- }
- if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-}
-
-template<bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
-::output_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
- float *toutp, float *outp, float *bias) const {
-
- float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
- float Ow[alpha][alpha][simd_w];
- float O[tile_size][tile_size][simd_w];
- float T[tile_size][alpha][simd_w];
-
- auto p = jit_wino_transform_call_s();
- p.src = toutp;
- p.dst = outp;
- p.G = G;
- p.M = O;
- p.Mw = Ow;
- p.T = T;
- p.bias = bias;
-
- int outw = is_fwd ? jcp.ow : jcp.iw;
- int outh = is_fwd ? jcp.oh : jcp.ih;
-
- int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
- for (int nb_tile_block_ur = 0;
- nb_tile_block_ur < jcp.nb_tile_block_ur;
- nb_tile_block_ur++) {
-
- for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
- tile_block_ur++) {
- int img = tile_index / (jcp.jtiles * jcp.itiles);
- int ti = tile_index % jcp.itiles;
- int tj = (tile_index / jcp.itiles) % jcp.jtiles;
-
- p.tile_block_ur = tile_block_ur;
- p.nb_tile_block_ur = nb_tile_block_ur;
- p.tile_block = tile_block;
- p.tj = tj;
- p.ti = ti;
- p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
- * outh * outw * jcp.dimM_simd_block;
-
- kernel_->output_transform_data_ker(&p);
-
- tile_index++;
- }
- }
-}
-
-
-template<bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
- ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp) const
-{
- float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
- 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
-
- float Iw[alpha][alpha][simd_w];
- float I[alpha][alpha][simd_w];
- float T[alpha][alpha][simd_w];
-
- auto p = jit_wino_transform_call_s();
-
- p.src = inp;
- p.dst = tinp;
- p.G = G;
- p.M = I;
- p.Mw = Iw;
- p.T = T;
-
- int tile_base_index = image * jcp.itiles * jcp.jtiles;
- int tile_block_ur = tile_base_index % jcp.tile_block_ur;
- int nb_tile_block_ur =
- (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
- int tile_block =
- (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
-
- for (int tj = 0; tj < jcp.jtiles; tj++) {
- for (int ti = 0; ti < jcp.itiles; ti++) {
-
- p.tile_block_ur = tile_block_ur;
- p.nb_tile_block_ur = nb_tile_block_ur;
- p.tile_block = tile_block;
- p.tj = tj;
- p.ti = ti;
-
- kernel_->input_transform_data_ker(&p);
-
- tile_block_ur++;
- if (tile_block_ur >= jcp.tile_block_ur) {
- tile_block_ur = 0;
- nb_tile_block_ur++;
- }
- if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
- nb_tile_block_ur = 0;
- tile_block++;
- }
- }
- }
-}
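[Note] The running counters above map a flat per-image tile index onto the (tile_block, nb_tile_block_ur, tile_block_ur) blocking that the GEMM loops expect. The decomposition and its inverse, as a hedged standalone sketch:

    // flat = (tile_block * nb_tile_block_ur + nb_ur) * tile_block_ur + ur
    struct tile_coords_ref { int tile_block, nb_ur, ur; };

    inline tile_coords_ref split_tile_index(int flat, int tile_block_ur,
            int nb_tile_block_ur) {
        tile_coords_ref c;
        c.ur = flat % tile_block_ur;
        c.nb_ur = (flat / tile_block_ur) % nb_tile_block_ur;
        c.tile_block = (flat / tile_block_ur) / nb_tile_block_ur;
        return c;
    }

    inline int join_tile_index(const tile_coords_ref &c, int tile_block_ur,
            int nb_tile_block_ur) {
        return (c.tile_block * nb_tile_block_ur + c.nb_ur) * tile_block_ur
                + c.ur;
    }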
-
-template <bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
- ::input_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp) const
-{
- float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
- 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
- float Iw[alpha][alpha][simd_w];
- float I[alpha][alpha][simd_w];
- float T[alpha][alpha][simd_w];
-
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
-
- array_offset_calculator<float, 5> input(inp,
- jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
- array_offset_calculator<float, 7> output(tinp,
- alpha, alpha,
- jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
- jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- auto p = jit_wino_transform_call_s();
-
- p.dst = tinp;
- p.G = G;
- p.M = I;
- p.Mw = Iw;
- p.T = T;
-
-
- int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
-
- for (int nb_tile_block_ur = 0;
- nb_tile_block_ur < jcp.nb_tile_block_ur;
- nb_tile_block_ur++) {
-
- for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
- tile_block_ur++) {
-
- int img = tile_index / (jcp.jtiles * jcp.itiles);
- int ti = tile_index % jcp.itiles;
- int tj = (tile_index / jcp.itiles) % jcp.jtiles;
- float *pinp_b = &(input(img, 0, 0, 0, 0));
-
- p.src = pinp_b;
- p.tile_block_ur = tile_block_ur;
- p.nb_tile_block_ur = nb_tile_block_ur;
- p.tj = tj;
- p.ti = ti;
-
- kernel_->input_transform_data_ker(&p);
-
- tile_index++;
- }
- }
-}
-
-template <bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
- float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const auto &p_ops = attr_->post_ops_;
-
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int outh = is_fwd ? jcp.oh : jcp.ih;
- const int outw = is_fwd ? jcp.ow : jcp.iw;
-
- /* Notation:
- FWD: dimM:oc, dimN:ntiles, dimK:ic,
- BWD: dimM:ic, dimN:ntiles, dimK:oc,
- FWD/BWD: V: src/diff_dst transform, U:weight transform,
- M:dst/diff_src transform */
- array_offset_calculator<float, 5> input(inp_ptr,
- jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
- jcp.dimK_reg_block);
- array_offset_calculator<float, 5> output(out_ptr,
- jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
- jcp.dimM_simd_block);
- array_offset_calculator<float, 6> weights(wei_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
- jcp.ic_simd_block, jcp.oc_simd_block);
- array_offset_calculator<float, 2> bias(bias_ptr,
- jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
-
- array_offset_calculator<float, 8> M(is_fwd
- ? scratchpad.template get<float>(key_wino_M)
- : scratchpad.template get<float>(key_wino_V),
- jcp.dimN_nb_block, jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
-
- auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
- ? wei_ptr
- : scratchpad.template get<float>(key_wino_U);
-
- array_offset_calculator<float, 8> U(wino_wei,
- jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimK_nb_block,
- jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> V(is_fwd
- ? scratchpad.template get<float>(key_wino_V)
- : scratchpad.template get<float>(key_wino_M),
- jcp.dimN_nb_block, alpha, alpha,
- jcp.dimN_block, jcp.dimK_nb_block,
- jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- const bool wants_padded_bias = jcp.with_bias
- && jcp.oc_without_padding != jcp.oc;
- float last_slice_bias[simd_w] = {0};
- if (wants_padded_bias) {
- for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
- last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
- }
-
- {
-
- parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
- [&](int img, int K_blk1, int K_blk2) {
- input_transform_data(img, jcp,
- &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
- 0, 0, 0)),
- &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
- });
-
- if (jcp.prop_kind != prop_kind::forward_inference) {
- parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
- (jcp.ic_block * jcp.ic_reg_block),
- [&](int ofm1, int ifm1, int ofm2, int ifm2) {
- float *U_base_ptr = is_fwd
- ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
- : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
- weight_transform_data(jcp,
- &(weights(
- ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
- ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
- 0, 0, 0, 0)),
- U_base_ptr);
- });
- }
-
- parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
- [&](int N_blk1, int oj, int oi, int M_blk1) {
- for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
- K_blk1++)
- for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
- kernel_->gemm_loop_ker(
- (float *)&(M(N_blk1, M_blk1, oj, oi,
- N_blk2, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi,
- K_blk1, 0, 0, 0, 0)),
- (const float *)&(V(N_blk1, oj, oi,
- N_blk2, K_blk1, 0, 0, 0)), K_blk1);
- });
-
- parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
- [&](int img, int M_blk1, int M_blk2) {
- const int M_blk =
- M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
-
- float *bias_ptr = wants_padded_bias
- && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
- ? last_slice_bias : &bias(M_blk, 0);
- output_transform_data(img, jcp, p_ops,
- &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
- &(output(img, M_blk, 0, 0, 0)), bias_ptr);
- });
-
- }
-}
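[Note] The W_S_G_D schedule above runs weight transform, source transform, Winograd-domain GEMMs and destination transform as separate parallel sweeps over the whole minibatch; each of the alpha*alpha transform points becomes an independent [ntiles x ic] by [ic x oc] GEMM in the dimN/dimK/dimM notation of the comment above. A hedged back-of-the-envelope on why this pays off (shapes hypothetical):

    #include <cstdint>

    // Multiply-accumulate counts on the forward data path.
    inline std::uint64_t wino_domain_macs(int alpha, int ntiles, int ic, int oc) {
        return (std::uint64_t)alpha * alpha * ntiles * ic * oc;
    }
    inline std::uint64_t direct_macs(int oh, int ow, int kh, int kw,
            int ic, int oc) {
        return (std::uint64_t)oh * ow * kh * kw * ic * oc;
    }
    // Example: 56x56 output, 3x3 kernel, F(4x4,3x3) => alpha = 6 and
    // ntiles = (56/4)*(56/4) = 196 per image, so 6*6*196 = 7056 Winograd
    // points versus 56*56*9 = 28224 direct points per (ic, oc) pair --
    // about a 4x reduction before counting the transform overhead.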
-
-template <bool is_fwd>
-void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
- float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const auto &p_ops = attr_->post_ops_;
-
- const int inph = is_fwd ? jcp.ih : jcp.oh;
- const int inpw = is_fwd ? jcp.iw : jcp.ow;
- const int outh = is_fwd ? jcp.oh : jcp.ih;
- const int outw = is_fwd ? jcp.ow : jcp.iw;
-
- array_offset_calculator<float, 5> input(inp_ptr,
- jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
- array_offset_calculator<float, 5> output(out_ptr,
- jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
- array_offset_calculator<float, 6> weights(wei_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
- jcp.ic_simd_block, jcp.oc_simd_block);
- array_offset_calculator<float, 2> bias(bias_ptr,
- jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
-
- auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
- ? wei_ptr
- : scratchpad.template get<float>(key_wino_U);
-
- array_offset_calculator<float, 8> U(wino_wei,
- jcp.dimM_nb_block,
- alpha, alpha,
- jcp.dimK_nb_block,
- jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, jcp.dimM_simd_block);
-
- array_offset_calculator<float, 8> M(is_fwd
- ? scratchpad.template get<float>(key_wino_M)
- : scratchpad.template get<float>(key_wino_V),
- 0, jcp.dimM_nb_block, alpha, alpha,
- jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
- jcp.dimN_reg_block, jcp.dimM_simd_block);
- array_offset_calculator<float, 8> V(is_fwd
- ? scratchpad.template get<float>(key_wino_V)
- : scratchpad.template get<float>(key_wino_M),
- 0, alpha, alpha, jcp.dimN_block,
- jcp.dimK_nb_block, jcp.dimK_block,
- jcp.dimN_reg_block, jcp.dimK_reg_block);
-
- const bool wants_padded_bias = jcp.with_bias
- && jcp.oc_without_padding != jcp.oc;
- float last_slice_bias[simd_w] = {0};
- if (wants_padded_bias) {
- for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
- last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
- }
-
- if (jcp.prop_kind != prop_kind::forward_inference) {
-
- parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
- [&](int ofm1, int ifm1, int ofm2, int ifm2) {
- float *U_base_ptr = is_fwd
- ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
- : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
- weight_transform_data(jcp,
- &(weights(
- ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
- ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
- 0, 0, 0, 0)),
- U_base_ptr);
- });
- }
-
- parallel_nd(jcp.tile_block, [&](int tile_block) {
- int ithr = mkldnn_get_thread_num();
-
- for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
- for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
-
- input_transform_tileblock_data(
- tile_block, jcp,
- &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
- &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
- }
- }
-
- for (int oj = 0; oj < alpha; oj++) {
- for (int oi = 0; oi < alpha; oi++) {
- for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
- for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
- for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
- kernel_->gemm_loop_ker(
- (float *)&(M(ithr, M_blk1, oj, oi,
- N_blk, 0, 0, 0)),
- (const float *)&(U(M_blk1, oj, oi, K_blk1,
- 0, 0, 0, 0)),
- (const float *)&(V(ithr, oj, oi,
- N_blk, K_blk1, 0, 0, 0)), K_blk1);
- }
- }
-
- for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
- for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
- M_blk2++) {
- const int M_blk =
- M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
-
- float *bias_ptr = wants_padded_bias
- && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
- ? last_slice_bias : &bias(M_blk, 0);
-
- output_transform_tileblock_data(tile_block, jcp, p_ops,
- &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
- &(output(0, M_blk, 0, 0, 0)), bias_ptr);
- }
- }
- });
-}
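[Note] In contrast, W_SGD above fuses source transform, GEMM and destination transform per tile_block, so each thread only needs Winograd scratch for one tile block instead of the whole batch. A hedged footprint comparison matching the init_scratchpad bookings in the header further below:

    #include <cstddef>

    // V (and, analogously, M) footprint in floats under each data schedule.
    inline std::size_t v_floats_w_s_g_d(int alpha, int mb, int channels,
            int ntiles_per_img) {
        return (std::size_t)alpha * alpha * mb * channels * ntiles_per_img;
    }
    inline std::size_t v_floats_w_sgd(int alpha, int nthr, int channels,
            int nb_tile_block_ur, int tile_block_ur) {
        return (std::size_t)alpha * alpha * nthr * channels
                * nb_tile_block_ur * tile_block_ur;
    }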
-
-template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
-template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
-
-namespace {
-
-void subarray_sum(size_t num_arrs, float *output, size_t nelems,
- float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
- using namespace nstl;
- const size_t block_size = 16 * 1024 / sizeof(float);
- const size_t blocks_number = nelems / block_size;
- const size_t tail = nelems % block_size;
-
-PRAGMA_OMP(parallel)
- {
- const int ithr = mkldnn_get_thread_num();
- const int nthr = mkldnn_get_num_threads();
- size_t start{ 0 }, end{ 0 };
- balance211(blocks_number, nthr, ithr, start, end);
-
- for (size_t nb = start; nb < end; ++nb) {
- size_t start_e = nb * block_size;
- size_t end_e = start_e + block_size;
- size_t input_start = max(start_e, min(input_starts[0], end_e));
- size_t input_end = max(start_e, min(input_ends[0], end_e));
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < input_start; e++) {
- output[e] = 0.f;
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] = input_ptrs[0][e];
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_end; e < end_e; e++) {
- output[e] = 0.f;
- }
-
- for (size_t a = 1; a < num_arrs; a++) {
- input_start = max(start_e, input_starts[a]);
- input_end = min(input_ends[a], end_e);
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
-
- if (tail != 0 && ithr == nthr - 1) {
- size_t start_e = nelems - tail;
- size_t end_e = nelems;
- size_t input_start = max(start_e, min(input_starts[0], end_e));
- size_t input_end = max(start_e, min(input_ends[0], end_e));
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < input_start; e++) {
- output[e] = 0.f;
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] = input_ptrs[0][e];
- }
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_end; e < end_e; e++) {
- output[e] = 0.f;
- }
-
- for (size_t a = 1; a < num_arrs; a++) {
- input_start = max(start_e, input_starts[a]);
- input_end = min(input_ends[a], end_e);
-
- PRAGMA_OMP_SIMD()
- for (size_t e = input_start; e < input_end; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
- }
-}
-
-const int max_threads_number = 1024;
-
-// Sum the input arrays into output; when reduce_to_first is set, output is
-// expected to alias input_ptrs[0] and the remaining arrays are accumulated
-// into it in place.
-void array_sum(size_t num_arrs, float *output,
- size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
- const size_t block_size = 16 * 1024 / sizeof(float);
- const size_t blocks_number = nelems / block_size;
- const size_t tail = nelems % block_size;
-
-PRAGMA_OMP(parallel)
- {
- const size_t ithr = mkldnn_get_thread_num();
- const size_t nthr = mkldnn_get_num_threads();
- size_t start{ 0 }, end{ 0 };
- balance211(blocks_number, nthr, ithr, start, end);
-
- for (size_t nb = start; nb < end; ++nb) {
- size_t start_e = nb * block_size;
- size_t end_e = start_e + block_size;
- if (!reduce_to_first) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = input_ptrs[0][e];
- }
- }
- for (size_t a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
-
- if (tail != 0 && ithr == nthr - 1) {
- size_t start_e = nelems - tail;
- size_t end_e = nelems;
- if (!reduce_to_first) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = input_ptrs[0][e];
- }
- }
- for (size_t a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += input_ptrs[a][e];
- }
- }
- }
- }
-}
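[Note] Both reductions above split the element range into ~16 KiB blocks and hand the blocks out with balance211. A reference for the split as assumed here (the actual helper lives in mkldnn_thread.hpp):

    #include <cstddef>

    // Assumed balance211 contract: n items over nthr threads, the first
    // (n % nthr) threads take one extra item.
    inline void balance211_ref(std::size_t n, std::size_t nthr, std::size_t ithr,
            std::size_t &start, std::size_t &end) {
        std::size_t base = n / nthr, rem = n % nthr;
        start = ithr * base + (ithr < rem ? ithr : rem);
        end = start + base + (ithr < rem ? 1 : 0);
    }
    // Example: n = 10, nthr = 4 -> ranges [0,3), [3,6), [6,8), [8,10).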
-} // anonymous namespace: reduction helpers for backward weights
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
-_execute_backward_weights_SDGtWo(const float *ptr_src,
- const float *ptr_diff_dst, float *ptr_diff_weights,
- float *ptr_diff_bias,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const int nthreads = jcp.nthr;
-
- array_offset_calculator<float, 5> src((float *)ptr_src,
- jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
- jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
- jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-
- array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
- 0, alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block,
- jcp.oc_reg_block,
- jcp.oc_simd_block);
-
- const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
- * jcp.ic / jcp.nb_ic;
- array_offset_calculator<float, 7>diff_weights_prv(
- scratchpad.get<float>(key_wino_U) + U_sz,
- 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
-
- array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
- 0, alpha, alpha,
- jcp.oc_block,
- jcp.nb_tile_block_ur,
- jcp.tile_block_ur,
- jcp.oc_reg_block,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
- 0, alpha, alpha,
- jcp.ic_block,
- jcp.nb_tile_block_ur,
- jcp.tile_block_ur,
- jcp.ic_simd_block);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
-
- auto trans_ker_p = jit_wino_transform_call_s();
- float I[alpha][alpha][simd_w];
- float T[alpha][alpha][simd_w];
- float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
- 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
- float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
- 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
- 1.13777777777778f};
- float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
-
-PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
-{
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
- [&](int ithr, int ofm){
- float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
- PRAGMA_OMP_SIMD()
- for (int v = 0; v < simd_w; v++) {
- pdbias[v] = 0.0f;
- }
- });
- }
-
- int ithr = mkldnn_get_thread_num();
- for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
- int first_tblk = 0;
-PRAGMA_OMP(for)
- for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
- int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
- int img = tile_index / (jcp.itiles * jcp.jtiles);
- trans_ker_p.ti = tile_index % jcp.itiles;
- trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
- trans_ker_p.M = I;
- trans_ker_p.T = T;
- trans_ker_p.G = G_I_3x3_4x4;
- for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
- int ifm = ifm1 * jcp.ic_block + ifm2;
- trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
- trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
- kernel_->src_transform(&trans_ker_p);
- }
-
- for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
- trans_ker_p.G = G_W_3x3_4x4;
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
- int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
- trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
- trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
- if (jcp.with_bias && ifm1 == 0) {
- trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
- kernel_->diff_dst_transform_wbias(&trans_ker_p);
- } else {
- kernel_->diff_dst_transform(&trans_ker_p);
- }
- }
-
- for (int oj = 0; oj < alpha; ++oj) {
- for (int oi = 0; oi < alpha; ++oi) {
- kernel_->gemm_loop_ker_first_iter(
- &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
- &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
- &(V(ithr, oj, oi, 0, 0, 0, 0)));
- }
- }
- trans_ker_p.G = G_O_3x3_4x4;
- for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
- for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
- int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
- + ofm3;
- for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
- int ifm = ifm1 * jcp.ic_block + ifm2;
- trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
- ofm2, ifm2, 0, ofm3, 0));
- trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
- ofm, ifm, 0, 0, 0, 0));
- if (first_tblk == 0) {
- kernel_->diff_weights_transform(&trans_ker_p);
- } else {
- kernel_->diff_weights_transform_accum(&trans_ker_p);
- }
- }
- }
- }
- }
- ++first_tblk;
- }
- }
-}
-
- // Reduce diff-weights
- {
- float *output = ptr_diff_weights;
- float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
- int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
- float *input_ptrs[max_threads_number];
- for (int i = 0; i < nthreads; ++i) {
- input_ptrs[i] = input_base + nelems * i;
- }
- array_sum(nthreads, output, nelems, input_ptrs, false);
-
- if (jcp.with_bias) {
- output = ptr_diff_bias;
- input_base = scratchpad.get<float>(key_conv_bia_reduction);
- for (int i = 0; i < nthreads; ++i) {
- input_ptrs[i] = input_base + jcp.oc * i;
- }
- array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
- false);
- }
- }
-}
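[Note] The SDGtWo path above keeps a private diff_weights copy and a private bias row per thread (diff_weights_prv, diff_bias_prv) and collapses them with array_sum at the end. The reduction is just an element-wise sum over the per-thread buffers, sketched here sequentially:

    #include <cstddef>
    #include <vector>

    // Sequential reference for the per-thread gradient reduction.
    void reduce_private_grads(float *out, const std::vector<const float *> &prv,
            std::size_t nelems) {
        for (std::size_t e = 0; e < nelems; ++e) {
            float acc = 0.f;
            for (const float *p : prv)
                acc += p[e];
            out[e] = acc;
        }
    }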
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
-_execute_backward_weights_S_D_Giot_W(const float *ptr_src,
- const float *ptr_diff_dst, float *ptr_diff_weights,
- float *ptr_diff_bias,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const int nthreads = jcp.nthr;
-
- array_offset_calculator<float, 5> src((float *)ptr_src,
- jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
- array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
- jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
- array_offset_calculator<float, 6> diff_weights((float *)ptr_diff_weights,
- jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
- array_offset_calculator<float, 1> diff_bias((float *)ptr_diff_bias, jcp.oc);
-
- array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
- jcp.nb_ic, jcp.nb_oc,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block,
- jcp.oc_reg_block,
- jcp.oc_simd_block);
-
- const int U_size = jcp.oc * jcp.ic * alpha * alpha;
- array_offset_calculator<float, 10> Us(
- scratchpad.get<float>(key_wino_U) + U_size,
- 0, jcp.nb_ic, jcp.nb_oc,
- alpha, alpha,
- jcp.oc_block, jcp.ic_block,
- jcp.ic_simd_block,
- jcp.oc_reg_block,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
- jcp.nb_oc,
- jcp.tile_block,
- alpha, alpha,
- jcp.oc_block,
- jcp.nb_tile_block_ur,
- jcp.tile_block_ur ,
- jcp.oc_reg_block,
- jcp.oc_simd_block);
-
- array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
- jcp.nb_ic,
- jcp.tile_block,
- alpha, alpha,
- jcp.ic_block,
- jcp.nb_tile_block_ur, jcp.tile_block_ur,
- jcp.ic_simd_block);
-
- array_offset_calculator<float, 2> diff_bias_prv(
- scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
-
- size_t input_starts[max_threads_number] = {0};
- size_t input_ends[max_threads_number] = {0};
- size_t first_tblk = 0;
-
- auto trans_ker_p = jit_wino_transform_call_s();
- float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
- 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
- float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
- 0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
- 0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
- float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
- float I[alpha][alpha][simd_w];
- float T[alpha][alpha][simd_w];
-
-PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
-{
- if (jcp.with_bias) {
- parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
- diff_bias_prv(ithr, ofm) = 0.0f;
- });
- }
-
- trans_ker_p.G = G_I_3x3_4x4;
- trans_ker_p.M = I;
- trans_ker_p.T = T;
-
- parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
- [&](int ifm1, int ifm2, int img){
- size_t ifm = ifm1 * jcp.ic_block + ifm2;
- size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
- size_t tblk3 = tile_base_index % jcp.tile_block_ur;
- size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
- % jcp.nb_tile_block_ur;
- size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
- / jcp.nb_tile_block_ur;
- trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
- trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
- trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
- kernel_->src_transform(&trans_ker_p);
- });
-
- int ithr = mkldnn_get_thread_num();
- trans_ker_p.G = G_W_3x3_4x4;
- parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
- [&](int ofm1, int ofm2, int img){
- int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
- size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
- size_t tblk3 = tile_base_index % jcp.tile_block_ur;
- size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
- % jcp.nb_tile_block_ur;
- size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
- / jcp.nb_tile_block_ur;
- trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
- trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
- trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
- if (jcp.with_bias) {
- trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
- kernel_->diff_dst_transform_wbias(&trans_ker_p);
- } else {
- kernel_->diff_dst_transform(&trans_ker_p);
- }
- });
-
- PRAGMA_OMP(barrier)
-
- parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
- [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
- if (first_tblk == 0) {
- input_starts[ithr] =
- (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
- 0, 0))
- - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
- 0, 0, 0));
- input_ends[ithr] = input_starts[ithr]
- + jcp.oc_block * jcp.ic_block
- * jcp.ic_simd_block * jcp.oc_reg_block
- * jcp.oc_simd_block;
- }
- else if (tblk1 == 0) {
- input_ends[ithr] += jcp.oc_block * jcp.ic_block
- * jcp.ic_simd_block * jcp.oc_reg_block
- * jcp.oc_simd_block;
- }
-
- if (first_tblk == 0 || tblk1 == 0) {
- kernel_->gemm_loop_ker_first_iter(
- &(Us(ithr, ifm1, ofm1, oj, oi,
- 0, 0, 0, 0, 0)),
- &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
- &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
- } else {
- kernel_->gemm_loop_ker(
- &(Us(ithr, ifm1, ofm1, oj, oi,
- 0, 0, 0, 0, 0)),
- &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
- &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
- }
- ++first_tblk;
- });
-}
-
- // Reduce diff-weights
- {
- float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
- size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
- float *input_ptrs[max_threads_number];
- for (int i = 0; i < nthreads; ++i)
- input_ptrs[i] = output + nelems * (i + 1);
- subarray_sum(nthreads, output, nelems, input_ptrs,
- input_starts, input_ends);
- }
-
- trans_ker_p.G = G_O_3x3_4x4;
-PRAGMA_OMP(parallel firstprivate(trans_ker_p))
- {
- parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
- [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
- int ofm = (ofm1 * jcp.oc_block + ofm2)
- * jcp.oc_reg_block + ofm3;
- int ifm = ifm1 * jcp.ic_block + ifm2;
- trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
- ofm2, ifm2, 0, ofm3, 0));
- trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
- 0, 0, 0, 0));
- kernel_->diff_weights_transform(&trans_ker_p);
- });
- }
-
- if (jcp.with_bias) {
- parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
- float* pbias = &(diff_bias(ofm1 * simd_w));
- float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
-
- const int blk_sz = ofm1 == jcp.oc / simd_w - 1
- ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
-
- PRAGMA_OMP_SIMD()
- for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
- pbias[ofm2] = pbias_prv[ofm2];
- }
-
- for (int ithr = 1; ithr < nthreads; ++ithr) {
- pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
- PRAGMA_OMP_SIMD()
- for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
- pbias[ofm2] += pbias_prv[ofm2];
- }
- }
- });
- }
-}
-
-}
-}
-}
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp
deleted file mode 100644
index f1a56aac70..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp
+++ /dev/null
@@ -1,386 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP
-#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace winograd_avx512_core {
-inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_winograd_conf_t &jcp) {
- using namespace utils;
- using namespace memory_tracking::names;
-
- size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
- size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles
- * jcp.jtiles;
- size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles
- * jcp.jtiles;
-
- switch (jcp.sched_policy) {
- case WSCHED_DATA_W_SGD:
- V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
- * jcp.tile_block_ur * jcp.ic;
- M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
- * jcp.tile_block_ur * jcp.oc;
- break;
- case WSCHED_WEI_SDGtWo:
- U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc
- * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw);
- M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
- * (jcp.oc / jcp.nb_oc);
- V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
- * (jcp.ic / jcp.nb_ic);
- break;
- case WSCHED_WEI_S_D_Giot_W:
- U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc;
- M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles;
- V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles;
- break;
- default: break;
- }
-
- scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
- scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
- scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
-
- if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) {
- size_t br_sz = (size_t)jcp.nthr * jcp.oc;
- scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
- }
-}
-}
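[Note] For the WSCHED_WEI_S_D_Giot_W booking above, U holds one shared copy of the transformed weight gradient plus one private copy per thread, which is what the `output + nelems * (i + 1)` addressing in the subarray_sum reduction of the .cpp file relies on. A hedged sizing example with hypothetical numbers:

    #include <cstddef>

    inline std::size_t u_floats_s_d_giot_w(int nthr, int alpha, int ic, int oc) {
        return (std::size_t)(nthr + 1) * alpha * alpha * ic * oc;
    }
    // Hypothetical: 28 threads, alpha = 6, ic = oc = 64 gives
    // 29 * 36 * 4096 = 4,276,224 floats (~16.3 MiB), large enough that
    // booking it in 2 MiB pages (PAGE_2M) is a sensible default.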
-
-template <bool is_fwd>
-struct _jit_avx512_core_fp32_wino_conv_4x3_t {
-
- _jit_avx512_core_fp32_wino_conv_4x3_t(
- const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
- : kernel_(nullptr), attr_(attr) {
- kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp);
- }
-
- ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; }
-
- protected:
- void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
- float *wp, float *twp) const;
- void input_transform_data(int image,
- const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp) const;
- void input_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp,
- float *inp, float *tinp) const;
- void output_transform_data(int image,
- const jit_conv_winograd_conf_t &jcp,
- const post_ops_t &p_ops, float *toutp, float *pout_b,
- float *bias) const;
- void output_transform_tileblock_data(int tile_block,
- const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
- float *toutp, float *outp, float *bias) const;
- void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const;
- void _execute_data_W_SGD(float *inp_ptr, float *out_ptr,
- float *wei_ptr, float *bias_ptr,
- const memory_tracking::grantor_t &scratchpad) const;
- _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_;
- const primitive_attr_t *attr_;
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t
- : _jit_avx512_core_fp32_wino_conv_4x3_t<true>
- , public cpu_primitive_t
- {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
- jit_avx512_core_fp32_wino_conv_4x3_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::
- init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_,
- *attr());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_fmt = desc()->prop_kind == prop_kind::forward_training
- ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any;
- return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
- }
- };
-
- jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd)
- : _jit_avx512_core_fp32_wino_conv_4x3_t<true>(apd->jcp_, apd->attr())
- , cpu_primitive_t(apd, true)
- {}
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
-
- auto scratchpad = this->scratchpad(ctx);
-
- switch ((pd()->jcp_).sched_policy) {
- case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
- (float *)bias, scratchpad);
- break;
- case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD((float *)src, dst, (float *)weights,
- (float *)bias, scratchpad);
- break;
- default:
- break;
- }
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t
- : _jit_avx512_core_fp32_wino_conv_4x3_t<false>,
- public cpu_primitive_t {
- struct pd_t : public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
- jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && mkldnn_thr_syncable()
- && desc()->prop_kind == prop_kind::backward_data
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel
- ::init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(),
- *diff_dst_md());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
- return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
- }
- };
-
- jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd)
- : _jit_avx512_core_fp32_wino_conv_4x3_t<false>(apd->jcp_, apd->attr())
- , cpu_primitive_t(apd, true)
- {}
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC);
-
- auto scratchpad = this->scratchpad(ctx);
-
- switch ((pd()->jcp_).sched_policy) {
- case WSCHED_DATA_W_S_G_D:
- this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
- (float *)weights, NULL, scratchpad);
- break;
-
- case WSCHED_DATA_W_SGD:
- this->_execute_data_W_SGD((float *)diff_dst, diff_src,
- (float *)weights, NULL, scratchpad);
- break;
-
- default:
- break;
- }
-
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t
- : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && mkldnn_thr_syncable()
- && desc()->prop_kind == prop_kind::backward_weights
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && set_default_formats();
- if (!ok)
- return status::unimplemented;
-
- status_t status =
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
- init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
- *diff_weights_md());
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- auto scratchpad = scratchpad_registry().registrar();
- winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
-
- return status;
- }
-
- jit_conv_winograd_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
- return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
- }
- };
-
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd, true)
- , kernel_(nullptr)
- {
- kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
- pd()->jcp_);
- }
-
- ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t()
- {
- delete kernel_;
- }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
-
- switch (kernel_->jcp.sched_policy) {
- case WSCHED_WEI_SDGtWo:
- _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights,
- diff_bias, scratchpad(ctx));
- break;
- case WSCHED_WEI_S_D_Giot_W:
- _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights,
- diff_bias, scratchpad(ctx));
- break;
- default:
- assert(kernel_->jcp.sched_policy != WSCHED_INVALID);
- break;
- }
- return status::success;
- }
-
-private:
- void _execute_backward_weights_SDGtWo(const float *src,
- const float *diff_dst, float *diff_weights, float *diff_bias,
- const memory_tracking::grantor_t &scratchpad) const;
- void _execute_backward_weights_S_D_Giot_W(const float *src,
- const float *diff_dst, float *diff_weights, float *diff_bias,
- const memory_tracking::grantor_t &scratchpad) const;
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
deleted file mode 100644
index 0d64a2d13a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
+++ /dev/null
@@ -1,2596 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include <math.h>
-
-#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
-
-#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace {
-
-using namespace mkldnn::impl::utils;
-
-unsigned int L1_cache_size = get_cache_size(1, true);
-unsigned int L2_cache_size = get_cache_size(2, true);
-unsigned int LLC_data_size = get_cache_size(3, false);
-
-// The test function takes jcp, the candidate and the current best.
-// It returns true if the new candidate is better.
-int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
- int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
-{
- int best_divisor = default_best;
- auto test_num
- = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
- if (test(jcp, num, best_divisor)) {
- best_divisor = num;
- }
- };
-
- for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
- if (number % divisor == 0) {
- test_num(jcp, divisor);
- test_num(jcp, number / divisor);
- }
- }
-
- return best_divisor;
-}
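// A minimal, self-contained sketch of the divisor search above: walk the
// divisor pairs of `number` up to sqrt(number) and keep whichever candidate a
// caller-supplied predicate prefers. The names below are illustrative only,
// not part of the mkl-dnn sources.
static int best_divisor_of(int number, int seed, bool (*better)(int, int)) {
    int best = seed;
    for (int d = 1; d * d <= number; ++d) {
        if (number % d != 0) continue;
        if (better(d, best)) best = d;                   // small divisor
        if (better(number / d, best)) best = number / d; // paired divisor
    }
    return best;
}
// e.g. the largest divisor of 48 not exceeding 8:
//   best_divisor_of(48, 1, [](int c, int b) { return c <= 8 && c > b; });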
-
-namespace {
-bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
- /* Determines if the current Winograd implementation is faster than direct.
-    The following conditions are empirical and based on performance data. */
- unsigned int ncores_per_socket =
- cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
- unsigned int nthreads = mkldnn_get_max_threads();
-
- if (jcp.prop_kind == prop_kind::forward_inference) {
- return jcp.mb >= 4;
- } else if (nthreads > ncores_per_socket) {
- double src_dst_transforms_per_core = alpha * alpha
- * (jcp.ic + jcp.oc)
- * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
- * ((jcp.ow + tile_size - 1) / tile_size)
- * sizeof(float) / 1024. / 1024. / nthreads;
- double wei_transform = alpha * alpha
- * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
-
- if (jcp.prop_kind == prop_kind::backward_weights) {
- if (src_dst_transforms_per_core < 0.3
- || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
- return false;
- else
- return true;
- } else {
- if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
- return false;
- }
- }
-
- return jcp.mb > 8;
-}
-}
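// A rough, standalone illustration of the src/dst transform-volume estimate
// used in the heuristic above, assuming alpha = 6 and tile_size = 4 as in this
// F(4x4, 3x3) kernel; it only shows the order of magnitude compared against
// the empirical per-thread thresholds (0.3 / 28 MB for backward weights,
// 2.0 MB otherwise). Names are illustrative, not mkl-dnn API.
static double transforms_mb_per_thread(int ic, int oc, int mb, int oh, int ow,
        int nthreads, int alpha = 6, int tile = 4) {
    double tiles = double(mb) * ((oh + tile - 1) / tile)
            * ((ow + tile - 1) / tile);
    return alpha * alpha * (ic + oc) * tiles * sizeof(float)
            / 1024. / 1024. / nthreads;
}
// e.g. ic = oc = 256, mb = 32, oh = ow = 28 on 28 threads gives roughly
// 3.9 MB per thread, above the 2.0 MB cutoff of the non-backward-weights branch.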
-
-/* assumes 512 bits registers */
-/* TODO: add support for strides */
-/* TODO: handle the prefetch distance automatically */
-typedef enum cache_t_ { L1, L2, L3 } cache_t;
-
-template <typename data_t>
-struct prefetcher_t {
- prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
- cache_t cache_type, size_t block_size, /* in number of elements*/
- int nb_instructions_in_block, int fma_ipc)
- : cg_(generator)
- , reg_base_addr_(reg_base_addr)
- , cache_type_(cache_type)
- , cache_block_size_(block_size)
- {
- nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
- prefetch_spread_
- = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
- prefetch_blk_
- = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
-
- /* assumption: when prefetching into Li, the data is already in L(i+1) */
- int cache_latency;
- switch (cache_type_) {
- case L1: cache_latency = 14; break;
- case L2: cache_latency = 250; break;
- case L3: cache_latency = 250; break;
- }
-
- prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
- }
-
- void prefetch(int instruction_number)
- {
- if (instruction_number % prefetch_spread_ == 0) {
- for (int i = 0; (i < prefetch_blk_)
- && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
- i++, prefetches_issued_++) {
- prefetch_inst_(cg_->EVEX_compress_addr(
- reg_base_addr_, (cache_block_size_ * prefetch_distance_)
- * sizeof(data_t)
- + (prefetches_issued_ * 64)));
- }
- }
- }
-
-private:
- void prefetch_inst_(const Xbyak::Address &addr)
- {
- switch (cache_type_) {
- case L1: cg_->prefetcht0(addr); break;
- case L2: cg_->prefetcht1(addr); break;
- case L3: cg_->prefetcht2(addr); break;
- default:
- break; // TODO: raise an exception or put an assert
- }
- }
-
- jit_generator *cg_;
- Xbyak::Reg64 reg_base_addr_;
- cache_t cache_type_;
- int cache_block_size_ = 0;
- int nb_cache_lines_to_prefetch_ = 0;
- int prefetches_issued_ = 0;
- int prefetch_spread_ = 0;
- int prefetch_blk_ = 0;
- int prefetch_distance_ = 0;
-};
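// A worked example of the spreading arithmetic above, with illustrative
// numbers: prefetching a block of 256 floats covers 256 / (64 / 4) = 16 cache
// lines; with nb_instructions_in_block = 32 this yields
// prefetch_spread_ = div_up(32, 16) = 2 and prefetch_blk_ = div_up(16, 32) = 1,
// so prefetch() issues one prefetch every second instruction until all 16
// lines of the block prefetch_distance_ blocks ahead have been requested.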
-
-// utilities to support kernel parameter selection
-bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
- int dimN_block, float C2_min, float C2_max) {
- float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
- * dimN_block * jcp.dimN_reg_block
- + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float);
- float L2_lb = C2_min * L2_cache_size;
- float L2_ub = C2_max * L2_cache_size;
- return (block_size > L2_lb && block_size < L2_ub);
-}
-
-bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
- int dimM_block, float C1_min, float C1_max) {
- float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
- * jcp.dimK_reg_block * jcp.dimM_reg_block
- + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
- + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
- * (float)sizeof(float);
- float L1_lb = C1_min * L1_cache_size;
- float L1_ub = C1_max * L1_cache_size;
- return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
-}
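// Both checks above compare a working-set estimate against a [C_min, C_max]
// fraction of the measured cache size. A worked example with illustrative
// numbers: dimM_block = 2, dimM_reg_block = 2, dimM_simd_block = 16,
// dimK_block = 2, dimK_reg_block = 16, dimN_reg_block = 14 gives
//   A: 2*16*2*16*2 = 2048 floats, B: 2*16*14 = 448, C: 2*16*14 = 448,
// about 11.5 KB in total, which passes check_L1_block_gemm for a 32 KB L1
// with C1_min = 0.1 and C1_max = 0.5.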
-bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
- int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
-{
- float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
- + dimM_block * dimK_block * dimK_reg_block
- * dimM_simd_block * dimM_reg_block
- + dimK_block * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs < rhs);
-}
-bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
- int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
-{
- float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
- * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L1_cache_size;
- return (lhs < rhs);
-}
-bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
- int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
- int dimM_simd_block, float C)
-{
- float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
- * dimM_simd_block * dimM_reg_block
- + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
- * dimM_simd_block * dimM_reg_block
- + nb_dimN_reg_block * dimK_nb_block * dimK_block
- * dimN_reg_block * dimK_reg_block)
- * (float)sizeof(float);
- float rhs = C * L2_cache_size;
- return (lhs < rhs);
-}
-
-bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
- int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
-{
- float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
- * (float)sizeof(float);
- float B_size = dimN_block * dimN_reg_block * dimK
- * (float)sizeof(float);
- return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
-}
-}
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate()
-{
- // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
- // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
- // dimM_reg_block++) // unrolled
- // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
- // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
- // dimK_reg_block++) // unrolled
- // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
- // C[dimM_block][dimM_reg_block][tile] +=
- // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
- // * broadcast(B[dimK_block][tile][dimK_reg_block]);
- // Notes:
- // jcp.kernel_kind defines embedded or explicit broadcast
- // dimM_reg_block=1 for embedded bcast kernel
-
- auto zmm_srcA = [=]() {
- return Xbyak::Zmm(0);
- };
- auto zmm_srcB = [=](int tile) {
- int idx = 1 + tile;
- assert(idx < 1 + jcp.dimN_reg_block);
- return Xbyak::Zmm(idx);
- };
- auto zmm_dstC = [=](int dimM_reg_block, int tile) {
- int idx{0};
- if (jcp.kernel_kind == embd_bcast)
- idx = 1 + tile;
- else
- idx = 1 + jcp.dimN_reg_block
- + dimM_reg_block * jcp.dimN_reg_block + tile;
- assert(idx < 32);
- return Xbyak::Zmm(idx);
- };
-
- auto prepare_output = [=]() {
- for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
- dimM_reg_block++) {
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm = zmm_dstC(dimM_reg_block, tile);
- vpxord(zmm, zmm, zmm);
- }
- }
- };
- auto store_output = [=](bool output_is_aligned) {
- Label save;
- cmp(reg_is_beta_zero, 0);
- je(save, T_NEAR);
-
- for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
- dimM_reg_block++) {
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm = zmm_dstC(dimM_reg_block,tile);
- int output_offset
- = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
- vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
- }
- }
-
- L(save);
- for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
- dimM_reg_block++) {
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- Zmm zmm = zmm_dstC(dimM_reg_block,tile);
- int output_offset
- = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
-
- // In W_SGD, output will be reused.
- if (output_is_aligned
- && jcp.dimK_nb_block == 1
- && jcp.sched_policy == WSCHED_DATA_W_S_G_D
- && (jcp.dimN * jcp.dimM * alpha * alpha
- * sizeof(float) > 2 * LLC_data_size))
- vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
- else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
- }
- }
- };
-
- auto inner_loops = [=]() {
- Label dimM_block_loop, dimK_block_loop;
-
- if (jcp.dimM_block > 1) {
- mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
- L(dimM_block_loop);
- }
-
- prepare_output();
-
- if (jcp.dimK_block > 1) {
- mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
- L(dimK_block_loop);
- }
-
- for (int dimK_reg_block = 0;
- dimK_reg_block < jcp.dimK_reg_block;
- dimK_reg_block ++) {
-
- if (jcp.kernel_kind == expl_bcast) {
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- vbroadcastss(zmm_srcB(tile),
- ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
- }
- }
-
- /* Performing the fmas */
-
- for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
- dimM_reg_block++) {
-
- vmovups(zmm_srcA(),
- zword[reg_srcA
- + jcp.dimK_reg_block * jcp.dimK_block * 64
- * dimM_reg_block
- + dimK_reg_block * 64]
- );
-
- for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
- if (jcp.kernel_kind == expl_bcast)
- vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
- zmm_srcB(tile));
- else
- vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
- EVEX_compress_addr(reg_srcB,
- 64 * tile + dimK_reg_block * 4, true));
- }
- }
- }
- add(reg_srcA, jcp.dimK_reg_block * 64);
- add(reg_srcB, jcp.dimN_reg_block * 64);
- if (jcp.dimK_block > 1) {
- sub(reg_dimK_block_loop_cnt, 1);
- jnz(dimK_block_loop);
- }
-
- Label unaligned_store, end_store;
- test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
- jnz(unaligned_store, T_NEAR);
- store_output(true);
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- store_output(false);
- }
- L(end_store);
-
- if (jcp.dimM_block > 1) {
- sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
- add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
- if (jcp.kernel_kind == expl_bcast) {
- add(reg_srcA,
- (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
- * jcp.dimK_block);
- }
- sub(reg_dimM_block_loop_cnt, 1);
- jnz(dimM_block_loop);
- }
- };
-
- /* Preamble */
- preamble();
-
- /* kernel */
- inner_loops();
-
- /* Postamble */
- postamble();
- ret();
-}
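// A plain scalar reference of the loop nest sketched in the comment at the top
// of gemm_loop_generate(). The flat index arithmetic below mirrors that
// comment only and is illustrative; it omits the 16-wide SIMD dimension and
// the broadcast of B that the generated kernel vectorizes.
static void wino_gemm_reference(float *C, const float *A, const float *B,
        int dimM_block, int dimM_reg_block, int dimK_block, int dimK_reg_block,
        int dimN_reg_block) {
    for (int mb = 0; mb < dimM_block; ++mb)
    for (int mr = 0; mr < dimM_reg_block; ++mr)
    for (int kb = 0; kb < dimK_block; ++kb)
    for (int kr = 0; kr < dimK_reg_block; ++kr)
    for (int tile = 0; tile < dimN_reg_block; ++tile) {
        int c = (mb * dimM_reg_block + mr) * dimN_reg_block + tile;
        int a = ((mb * dimM_reg_block + mr) * dimK_block + kb)
                * dimK_reg_block + kr;
        int b = (kb * dimN_reg_block + tile) * dimK_reg_block + kr;
        C[c] += A[a] * B[b];
    }
}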
-
-void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
- ::weights_transform_data_ker_generate()
-{
- bool is_fwd = one_of(jcp.prop_kind,
- mkldnn_forward_training, mkldnn_forward_inference);
- int kh = jcp.kh;
- int kw = jcp.kw;
-
- auto zmm_temp = Xbyak::Zmm(31);
- auto zmm_zero = Xbyak::Zmm(30);
-
- auto zmm_M = [=](int i) {
- return Xbyak::Zmm(i);
- };
- auto zmm_MT = [=](int i) {
- return Xbyak::Zmm(i + simd_w);
- };
-
- auto zmm_G = [=](int i) {
- return Xbyak::Zmm(i);
- };
- auto zmm_F = [=](int i) {
- return Xbyak::Zmm(alpha + i);
- };
- auto zmm_T = [=](int i) {
- return Xbyak::Zmm(alpha + 3 + i);
- };
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(2 * alpha + 3 + i);
- };
-
- auto zmm_load = [=](int i) {
- return Xbyak::Zmm(i);
- };
-
- auto init_G = [=]() {
- mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
- for (int i = 0; i < alpha; i++) {
- vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
- }
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- };
-
- auto trans16x16 = [=]() {
- for (int i = 0; i < simd_w; i+=2 ) {
- vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
- vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
- vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
- vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
- }
- for (int i = 0; i < simd_w; i+=4 ) {
- vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
- vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
- vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
- vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
- }
- for (int i = 0; i < simd_w; i += 8) {
- vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
- vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
- vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
- vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
- vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
- vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
- vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
- vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
- }
- {
- int i = 0;
- int mask = 0x88;
- vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
- vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
- vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
- vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
- vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
- vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
- vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
- vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
- vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
- vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
- vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
- vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
- vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
- vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
- vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
- vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
- mask = 0xdd;
- vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
- vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
- vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
- vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
- vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
- vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
- vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
- vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
- vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
- vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
- vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
- vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
- vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
- vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
- vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
- vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
- }
- };
-
- auto load_src = [=]() {
- mov(wreg_src, ptr[param1 + GET_OFF(src)]);
- mov(wreg_F, ptr[param1 + GET_OFF(M)]);
- for (int j = 0; j < kh; j++) {
- for (int i = 0; i < kw; i++) {
- if (is_fwd) {
- for (int v1 = 0; v1 < simd_w; v1++) {
- int offset_src = (j * kw * simd_w * simd_w
- + i * simd_w * simd_w + v1 * simd_w) * typesize;
- int offset_F = (j * kw * simd_w * simd_w
- + i * simd_w * simd_w + v1 * simd_w) * typesize;
- vmovups(zmm_temp, ptr[wreg_src + offset_src]);
- vmovups(ptr[wreg_F + offset_F], zmm_temp);
- }
- } else {
- int offset_src = ((2 - j) * kw * simd_w * simd_w
- + (2 - i) * simd_w * simd_w) * typesize;
- int offset_F = (j * kw * simd_w * simd_w
- + i * simd_w * simd_w) * typesize;
- lea(wreg_M, ptr[wreg_src + offset_src]);
- lea(wreg_MT, ptr[wreg_F + offset_F]);
- trans16x16();
- }
- }
- }
- };
-
- auto store_dst = [=]() {
- mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
- mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
-
- Label Loop_j;
- mov(wreg_cnt_j, 0);
- mov(wreg_dst_aux, wreg_dst);
- mov(wreg_Fw_aux, wreg_Fw);
-
- int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
- * jcp.dimK_block * simd_w * simd_w;
-
- L(Loop_j);
- {
- for (int i = 0; i < alpha; i++) {
- // touch pages
- vmovups(zmm_load(0), ptr[wreg_Fw_aux
- + (i * simd_w * simd_w) * typesize]);
- mov(wreg_dst_idx, i * dim5 * typesize);
- vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
- }
- for (int i = 0; i < alpha; i++) {
- for (int v1 = 1; v1 < simd_w; v1++) {
- int offset_Fw = (i * simd_w * simd_w + v1 * simd_w)
- * typesize;
- vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
- }
- mov(wreg_dst_idx, i * dim5 * typesize);
- for (int v1 = 1; v1 < simd_w; v1++) {
- int offset_dst = v1 * simd_w * typesize;
- vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
- zmm_load(v1));
- }
- }
- add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
- add(wreg_dst_aux, alpha * dim5 * typesize);
- add(wreg_cnt_j, 1);
- cmp(wreg_cnt_j, alpha);
- jl(Loop_j, T_NEAR);
- }
- };
-
- auto trans_W_4x4_3x3 = [=]() {
- auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
- vmovups(dst, a);
- vfmadd231ps(dst, b, c);
- };
- auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
- vmulps(zmm_temp, b, c);
- vsubps(dst, a, zmm_temp);
- };
- auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
- vsubps(dst, zmm_zero, a);
- vfnmadd231ps(dst, b, c);
- };
-
- mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
- mov(wreg_F, ptr[param1 + GET_OFF(M)]);
- mov(wreg_T, ptr[param1 + GET_OFF(T)]);
-
- Label Loop_j;
- mov(wreg_cnt_j, 0);
- L(Loop_j);
- mov(wreg_F_aux, wreg_F);
- mov(wreg_Fw_aux, wreg_Fw);
- mov(wreg_temp, wreg_cnt_j);
- shl(wreg_temp, 4 + 2);
- lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
- lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);
-
- for (int i = 0; i < 3; i++) {
- for (int idx = 0; idx < 3; idx ++) {
- vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
- * simd_w + i * simd_w * simd_w) * typesize]);
- }
- vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
- fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
- fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));
-
- vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
- fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
- fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
- fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
- fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
- vmovaps(zmm_T(5), zmm_F(2));
-
- for (int idx = 0; idx < 6; idx ++) {
- vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
- * typesize], zmm_T(idx));
- }
- }
- for (int i = 0; i < 6; i++) {
-
- for (int idx = 0; idx < 3; idx ++) {
- vmovups(zmm_T(idx), ptr[wreg_T
- + (i * 3 * simd_w + idx * simd_w) * typesize]);
- }
- vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
- fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
- fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));
-
- vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
- fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
- fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
- fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
- fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
- vmovaps(zmm_F(5), zmm_T(2));
-
- for (int l = 0; l < 6; l++) {
- vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
- + l * simd_w * simd_w) * typesize], zmm_F(l));
- }
- }
- add(wreg_cnt_j, 1);
- cmp(wreg_cnt_j, 16);
- jl(Loop_j, T_NEAR);
- };
-
- auto inner_loops = [=]() {
- load_src();
- init_G();
- trans_W_4x4_3x3();
- store_dst();
- };
-
- preamble();
- inner_loops();
- postamble();
-}
-
-void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
- ::output_transform_data_ker_generate()
-{
- bool is_fwd = one_of(jcp.prop_kind,
- mkldnn_forward_training, mkldnn_forward_inference);
- int outw = is_fwd ? jcp.ow : jcp.iw;
- int outh = is_fwd ? jcp.oh : jcp.ih;
- bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
- bool with_bias = jcp.with_bias;
- bool with_relu = jcp.with_eltwise;
- bool with_relu_postsum = jcp.with_relu_postsum;
- bool with_sum = jcp.with_sum;
-
- auto zmm_zero = Xbyak::Zmm(0);
- auto zmm_temp = Xbyak::Zmm(31);
- auto zmm_G = [=](int i) {
- return Xbyak::Zmm(1 + i);
- };
- auto zmm_O = [=](int i) {
- return Xbyak::Zmm(1 + alpha + i);
- };
- auto zmm_T = [=](int i) {
- return Xbyak::Zmm(1 + 2 * alpha + i);
- };
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(1 + 3 * alpha + i);
- };
-
- auto init_G = [=]() {
- mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
- for (int i = 0; i < 6; i++) {
- vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
- }
- };
-
- auto load_src = [=]() {
- mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
- mov(oreg_src, ptr[param1 + GET_OFF(src)]);
-
- mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
- imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
- (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
- * jcp.dimM_simd_block * typesize);
- add(oreg_src, oreg_nb_tile_block_ur);
-
- mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
- imul(oreg_tile_block_ur, oreg_tile_block_ur,
- jcp.dimM_simd_block * typesize);
- add(oreg_src, oreg_tile_block_ur);
-
- if (not_tiled) {
- mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
- imul(oreg_tile_block, oreg_tile_block,
- jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
- * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
- * jcp.dimM_simd_block * typesize);
- add(oreg_src, oreg_tile_block);
- }
-
- int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
- * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- int j_base_offset = j * alpha * last4dim;
- int i_base_offset = i * last4dim;
- vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
- vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
- * typesize], zmm_temp);
- }
- }
- };
-
- auto store_dst = [=]() {
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
- mov(oreg_O, ptr[param1 + GET_OFF(M)]);
- mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
- shl(oreg_ydim, 2); // tj * tile_size (==4)
- mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
- shl(oreg_xdim, 2); // ti * tile_size (==4)
-
- if (with_bias)
- mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
-
- auto store_one = [=](int j, int i, bool is_aligned) {
- auto zmm_O = Xbyak::Zmm(31);
- auto zmm_relu_ns = Xbyak::Zmm(30);
- auto xmm_relu_ns = Xbyak::Xmm(30);
- int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
-
- vmovups(zmm_O, ptr[oreg_O + offset]);
- if (is_fwd) {
- if (with_bias) {
- vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
- }
- if (with_relu) {
- if (jcp.eltwise.alpha == 0) {
- vmaxps(zmm_O, zmm_O, zmm_zero);
- } else {
- Opmask kmask = Opmask(7);
- mov(imm_addr64, float2int(jcp.eltwise.alpha));
- vmovq(xmm_relu_ns, imm_addr64);
- vbroadcastss(zmm_relu_ns, xmm_relu_ns);
- vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
- vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
- }
- }
- }
- if (with_sum) {
- vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
- if (with_relu_postsum) // orig: with_relu_postsum
- vmaxps(zmm_O, zmm_O, zmm_zero);
- }
- if (is_aligned)
- vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
- else
- vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
- };
-
- auto i_loop = [=](int j, bool is_aligned) {
- for (int i = 0; i < tile_size; i++) {
- Label next;
- mov(oreg_temp, oreg_xdim);
- add(oreg_temp, i);
- cmp(oreg_temp, outw);
- jge(next, T_NEAR);
- shl(oreg_temp, 4 + 2); // * 16 * 4
-
- store_one(j, i, is_aligned);
-
- L(next);
- }
- };
-
-
- for (int j = 0; j < tile_size; j++) {
- Label next, unaligned;
- mov(oreg_temp, oreg_ydim);
- add(oreg_temp, j);
- cmp(oreg_temp, outh);
- jge(next, T_NEAR);
-
- mov(oreg_out_j, oreg_dst);
- imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
- add(oreg_out_j, oreg_temp);
-
- test(oreg_dst, 63);
- jnz(unaligned, T_NEAR);
-
- i_loop(j, true);
- jmp(next, T_NEAR);
-
- L(unaligned);
- i_loop(j, false);
-
- L(next);
- }
- };
-
- auto trans_O_4x4_3x3 = [=]() {
- auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
- vmulps(dst, v1, u1);
- vfmadd231ps(dst, v2, u2);
- };
- mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
- mov(oreg_T, ptr[param1 + GET_OFF(T)]);
- mov(oreg_O, ptr[param1 + GET_OFF(M)]);
-
- for (int i = 0; i < alpha; i++) {
- for (int j = 0; j < alpha; j++) {
- vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
- + i * simd_w) * typesize]);
- }
-
- vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
- vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
- vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
- vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
-
- vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
- vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
- fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
- fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
- fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
- vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
-
- for (int j = 0; j < tile_size; j++) {
- vmovups(ptr[oreg_T + (j * alpha * simd_w
- + i * simd_w) * typesize], zmm_T(j));
- }
- }
- for (int j = 0; j < tile_size; j++) {
- for (int i = 0; i < alpha; i++) {
- vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
- + i * simd_w) * typesize]);
- }
- vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
- vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
- vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
- vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
-
- vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
- vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
- fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
- fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
- fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
- vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
-
- for (int i = 0; i < tile_size; i++) {
- vmovups(ptr[oreg_O + (j * tile_size * simd_w
- + i * simd_w) * typesize], zmm_O(i));
- }
- }
- };
-
- auto inner_loops = [=]() {
- init_G();
- load_src();
- trans_O_4x4_3x3();
- store_dst();
- };
-
- preamble();
- inner_loops();
- postamble();
-}
-
-void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
- ::input_transform_data_ker_generate()
-{
- bool is_fwd = one_of(jcp.prop_kind,
- mkldnn_forward_training, mkldnn_forward_inference);
- int inpw = is_fwd ? jcp.iw : jcp.ow;
- int inph = is_fwd ? jcp.ih : jcp.oh;
- int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
- int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
- int wp_max = inpw + l_pad;
- int hp_max = inph + t_pad;
- bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
- int G_size = 9;
-
- auto zmm_zero = Xbyak::Zmm(0);
- auto zmm_temp = Xbyak::Zmm(31);
- auto zmm_G = [=](int i) {
- return Xbyak::Zmm(1 + i);
- };
- auto zmm_I = [=](int i) {
- return Xbyak::Zmm(1 + G_size + i);
- };
- auto zmm_T = [=](int i) {
- return Xbyak::Zmm(1 + G_size + alpha + i);
- };
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
- };
-
- auto init_G = [=]() {
- mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
- for (int i = 0; i < G_size; i++) {
- vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
- }
- };
-
- auto load_src = [=]() {
- mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
- mov(ireg_I, ptr[param1 + GET_OFF(M)]);
-
- xor_(ireg_zero, ireg_zero);
- vpxord(zmm_zero, zmm_zero, zmm_zero);
-
- mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
- shl(ireg_ydim, 2); // tj * tile_size (==4)
- mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
- shl(ireg_xdim, 2); // ti * tile_size (==4)
-
- for (int j = 0; j < alpha; j++) {
- mov(ireg_temp, ireg_ydim);
- add(ireg_temp, j);
-
- mov(ireg_mask_j, 0xffff);
- cmp(ireg_temp, t_pad);
- cmovl(ireg_mask_j, ireg_zero);
- cmp(ireg_temp, hp_max);
- cmovge(ireg_mask_j, ireg_zero);
-
- sub(ireg_temp, t_pad);
- imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
- mov(ireg_inp_j, ireg_src);
- add(ireg_inp_j, ireg_temp);
-
- for (int i = 0; i < alpha; i++) {
-
- mov(ireg_temp, ireg_xdim);
- add(ireg_temp, i);
-
- mov(ireg_mask, 0xffff);
- cmp(ireg_temp, l_pad);
- cmovl(ireg_mask, ireg_zero);
- cmp(ireg_temp, wp_max);
- cmovge(ireg_mask, ireg_zero);
- and_(ireg_mask, ireg_mask_j);
-
- sub(ireg_temp, l_pad);
- shl(ireg_temp, 4 + 2);
-
- vpxord(zmm_temp, zmm_temp, zmm_temp);
- Opmask kmask = Opmask(7);
- kmovw(kmask, ireg_mask_32);
- vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
- vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
- * typesize], zmm_temp);
- }
- }
- };
-
- auto store_Iw = [=]() {
-
- mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
- mov(ireg_output, ptr[param1 + GET_OFF(dst)]);
-
- bool streamout
- = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
- > 2 * LLC_data_size
- ? true : false;
-
- if (not_tiled) {
- mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
- imul(ireg_tile_block, ireg_tile_block,
- alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
- * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
- * typesize);
- }
-
- mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
- imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
- jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
- * jcp.dimK_reg_block * typesize);
-
- mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
- imul(ireg_tile_block_ur, ireg_tile_block_ur,
- jcp.dimK_reg_block * typesize);
-
- add(ireg_output, ireg_nb_tile_block_ur);
- add(ireg_output, ireg_tile_block_ur);
- if (not_tiled)
- add(ireg_output, ireg_tile_block);
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w
- + i * simd_w) * typesize]);
-
- int j_base_offset =
- j * alpha * jcp.dimN_block * jcp.dimK_nb_block
- * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
- * typesize;
- int i_base_offset =
- i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
- * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
-
- if (not_tiled && streamout)
- vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
- zmm_temp);
- else
- vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
- zmm_temp);
- }
- }
- };
-
- auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
- vmulps(zmm_temp, a, b);
- vaddps(dst, zmm_temp, c);
- };
-
- auto trans_I_4x4_3x3 = [=]() {
- mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
- mov(ireg_T, ptr[param1 + GET_OFF(T)]);
- mov(ireg_I, ptr[param1 + GET_OFF(M)]);
-
- mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
- for (int i = 0; i < alpha; i++) {
- for (int idx = 0; idx < alpha; idx++) {
- vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
- + i * simd_w) * typesize]);
- int j_base_offset =
- i * alpha * jcp.dimN_block * jcp.dimK_nb_block
- * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
- * typesize;
- int idx_base_offset =
- idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
- * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
- prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
- }
-
- fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
- fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
- fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
- fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
- fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
- fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
-
- fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
- fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
- fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
- fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
- fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
- fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
-
- for (int idx = 0; idx < alpha; idx++) {
- vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
- * typesize],zmm_T(idx));
- }
- }
- for (int i = 0; i < alpha; i++) {
- for (int idx = 0; idx < alpha; idx++) {
- vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
- * simd_w) * typesize]);
- }
-
- fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
- fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
- fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
- fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
- fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
- fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
-
- fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
- fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
- fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
- fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
- fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
- fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
-
- for (int idx = 0; idx < alpha; idx++) {
- vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
- * typesize],zmm_I(idx));
- }
- }
- };
-
- auto inner_loops = [=]() {
- init_G();
- load_src();
- trans_I_4x4_3x3();
- store_Iw();
- };
-
- preamble();
- inner_loops();
- postamble();
-}
-
-status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d)
-{
- if (!mayiuse(avx512_core)) {
- return status::unimplemented;
- }
-
- jcp.nthr = mkldnn_get_max_threads();
-
- jcp.ver = ver_avx512_core;
- jcp.prop_kind = cd.prop_kind;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
- jcp.kh = weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
- jcp.r_pad = nstl::max(
- 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
- jcp.b_pad = nstl::max(
- 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
- jcp.ohp = jcp.oh;
- jcp.owp = jcp.ow;
-
- bool ok_to_pad_channels = jcp.ngroups == 1;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- // Checking conditions not supported by these kernels
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
-
- if (jcp.ngroups != 1)
- return status::unimplemented;
- if ((jcp.kh != 3) || (jcp.kw != 3))
- return status::unimplemented;
- if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
- return status::unimplemented;
- if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
- return status::unimplemented;
- if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
- return status::unimplemented;
-
- format_tag_t dat_tag = nChw16c;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- if (jcp.src_tag != dat_tag) return status::unimplemented;
- if (jcp.dst_tag != dat_tag) return status::unimplemented;
-
- if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) {
- format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- if (jcp.wei_tag != wei_tag)
- return status::unimplemented;
- }
-
- bool layout_consistency = true
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && (one_of(weights_d.format_kind(),
- format_kind::any, format_kind::wino)
- || (jcp.ic <= weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= weights_d.padded_dims()[with_groups + 0]));
- if (!layout_consistency)
- return status::unimplemented;
-
- return status::success;
-}
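// The implied right/bottom padding above follows standard convolution
// arithmetic; a small self-contained helper with illustrative naming (not
// mkl-dnn API) showing the same formula:
static inline int implied_end_pad(int out, int in, int k, int stride,
        int begin_pad) {
    int p = (out - 1) * stride + k - in - begin_pad;
    return p > 0 ? p : 0;
}
// e.g. a same-padded 3x3, stride-1 layer with in = out = 56 and begin_pad = 1
// gives implied_end_pad(56, 56, 3, 1, 1) == 1.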
-
-void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
-
- /* ----------- dimM reg block ---------------------*/
- auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimM_reg_block, int current_best) {
- int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
- return (dimM_reg_block >= 1)
- && (dimM_reg_block <= max_dimM_reg_block )
- && (dimM_reg_block > current_best);
- };
- jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
- jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
-
- /* ----------- dimN reg block ---------------------*/
-
- auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_reg_block, int current_best) {
- return jcp.kernel_kind == embd_bcast
- ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
- : dimN_reg_block >= 1
- && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
- < jcp.nb_reg
- && dimN_reg_block > current_best;
- };
- jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
- jcp.dimN, 1, test_cond_dimN_reg_block);
-}
-
-status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
- if (jcp.ver != ver_avx512_core)
- return status::unimplemented;
-
- jcp.kernel_kind = embd_bcast;
-
- set_kernel_dims_reg_block(jcp);
-
- /*-------------- L2 blocking for dimN block ---------*/
-
- auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_block, int current_best) {
- return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
- && (dimN_block > current_best)
- && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
- >= 1.5 * mkldnn_get_max_threads());
- };
-
- jcp.dimN_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
- jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
-
- if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
- && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) {
-
- /* ------------------- L1 blocking for GEMM --------------*/
- /* -------------------- Choose dimK block ----------------*/
-
- auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
- int dimK_block, int current_best) {
- return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
- && (dimK_block > current_best);
- };
-
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
-
- if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
- jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
-
- /* -------------- Choose dimM block -------------------*/
- auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
- int dimM_block, int current_best) {
- return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
- 0.2, 0.5) && (dimM_block > current_best);
- };
-
- jcp.dimM_block = get_divisor_satisfying_cond(jcp,
- jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
- test_cond_dimM_block);
- jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
- / jcp.dimM_simd_block;
-
- jcp.sched_policy = WSCHED_DATA_W_SGD;
- return status::success;
- }
-
- }
- return status::unimplemented;
-}
-
-void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
-
- set_kernel_dims_reg_block(jcp);
-
- //********************* Choosing dimK_block **********************//
- auto test_cond1_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
- 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
- && (dimK_block > current_best);
- };
-
- auto test_cond1_bis_dimK_block = [](
- jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
- return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
- jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
- jcp.dimM_simd_block, .9f)
- && (dimK_block > current_best);
- };
-
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
- // If we are not able to use streams, we fall back to condition [1]
- if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
- jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
-
- //********************* Choosing dimM_block **********************//
- auto test_cond1_dimM_block = [](
- jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
- return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
- jcp.dimM_simd_block, .5f)
- && (dimM_block > current_best);
- };
-
- auto test_cond1_bis_dimM_block = [](
- jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
- return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
- jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
- jcp.dimM_simd_block, .3f)
- && (dimM_block > current_best);
- };
-
- if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
- jcp.dimM_block = get_divisor_satisfying_cond(
- jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
- test_cond1_dimM_block);
- else
- jcp.dimM_block = get_divisor_satisfying_cond(jcp,
- jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
- test_cond1_bis_dimM_block);
- jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
- * jcp.dimM_reg_block);
-
- //******************* Choosing dimN_block *******************//
- auto test_cond2_dimN_block = [](
- jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
- return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
- jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
- jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
- && (dimN_block > current_best);
- };
-
- jcp.dimN_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
- jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
-}
-
-status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
-
- jcp.kernel_kind = expl_bcast;
- set_kernel_blocking_DATA_W_S_G_D(jcp);
- if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
- jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
- .1f, .35f))) {
- jcp.kernel_kind = embd_bcast;
- set_kernel_blocking_DATA_W_S_G_D(jcp);
- }
- jcp.sched_policy = WSCHED_DATA_W_S_G_D;
- return status::success;
-}
-
-status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel(
- jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
-{
- jcp.nb_reg = 32;
- jcp.dimN = dimN;
- jcp.dimK = dimK;
- jcp.dimM = dimM;
- jcp.sched_policy = WSCHED_INVALID;
-
- jcp.dimK_reg_block = 16;
- jcp.dimM_simd_block = 16;
-
- if (jcp.kernel_kind == embd_bcast) {
- jcp.dimM_reg_block = 1;
- }
-
- if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
- set_wsched_DATA_W_S_G_D_avx512_core(jcp);
-
- assert(jcp.sched_policy != WSCHED_INVALID);
- return status::success;
-}
-
-bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_relu(0) || is_sum(0); // relu or sum
- case 2: return (is_sum(0) && is_relu(1))
- || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
- case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_t &src_md, memory_desc_t &weights_md,
- const memory_desc_t &dst_md, const primitive_attr_t &attr) {
-
- status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md);
-
- if (st != status::success)
- return st;
-
- // Winograd specific initialization
- jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
- jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
-
- status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
-
- jcp.ic_simd_block = jcp.dimK_reg_block;
- jcp.ic_block = jcp.dimK_block;
- jcp.nb_ic = jcp.dimK_nb_block;
- jcp.oc_simd_block = jcp.dimM_simd_block;
- jcp.oc_block = jcp.dimM_block;
- jcp.oc_reg_block = jcp.dimM_reg_block;
- jcp.ic_reg_block = 1;
- jcp.nb_oc = jcp.dimM_nb_block;
- jcp.tile_block_ur = jcp.dimN_reg_block;
- jcp.nb_tile_block_ur = jcp.dimN_block;
- jcp.tile_block = jcp.dimN_nb_block;
-
- /* re-create weights primitive descriptor
- and set weights wino_blocking */
- if (cd.prop_kind == mkldnn_forward_inference) {
- memory_desc_t expect_wei_md = weights_md;
-
- expect_wei_md.format_kind = format_kind::wino;
- expect_wei_md.data_type = data_type::f32;
- mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
- wd.wino_format = mkldnn_wino_wei_OBaaIBOIio;
- wd.r = 3;
- wd.alpha = 6;
-
- wd.ic = jcp.ic;
- wd.oc = jcp.oc;
- wd.ic_block = jcp.dimK_reg_block;
- wd.oc_block = jcp.dimM_simd_block;
- wd.ic2_block = jcp.dimK_block;
- wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
- size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
- wd.size = max_size;
- wd.adj_scale = 1.f;
-
- if (weights_md.format_kind == format_kind::any)
- weights_md = expect_wei_md;
- if (weights_md != expect_wei_md)
- return status::unimplemented;
- }
-
- return res;
-}
-
-status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d)
-{
- status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
-
- if (st != status::success)
- return st;
-
- jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
-
- jcp.oc_simd_block = jcp.dimK_reg_block;
- jcp.oc_block = jcp.dimK_block;
- jcp.nb_oc = jcp.dimK_nb_block;
- jcp.ic_simd_block = jcp.dimM_simd_block;
- jcp.ic_block = jcp.dimM_block;
- jcp.ic_reg_block = jcp.dimM_reg_block;
- jcp.oc_reg_block = 1;
- jcp.nb_ic = jcp.dimM_nb_block;
- jcp.tile_block_ur = jcp.dimN_reg_block;
- jcp.nb_tile_block_ur = jcp.dimN_block;
- jcp.tile_block = jcp.dimN_nb_block;
-
- return res;
-}
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
-src_transform_generate() {
- constexpr int G_size = 9;
- const size_t ifwp = jcp.iw + jcp.l_pad;
- const size_t ifhp = jcp.ih + jcp.t_pad;
-
- auto zmm_G = [=](int i) {
- return Xbyak::Zmm(i);
- };
- auto zmm_I = [=](int i) {
- return Xbyak::Zmm(G_size + i);
- };
- auto zmm_T = [=](int i) {
- return Xbyak::Zmm(G_size + alpha + i);
- };
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(G_size + 2 * alpha + i);
- };
-
- auto init_G = [=]() {
- mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
- for (int i = 0; i < G_size; i++) {
- vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
- }
- };
-
- auto load_src = [=]() {
- mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
- xor_(reg_zero, reg_zero);
-
- mov(reg_ydim, reg_tj);
- shl(reg_ydim, 2); //tj * tile_size(=4)
-
- for (int j = 0; j < alpha; j++) {
- /* check if tile index is within physical spatial boundaries*/
- mov(reg_maskj, 0xffff);
- cmp(reg_ydim, jcp.t_pad);
- cmovl(reg_maskj, reg_zero);
- cmp(reg_ydim, ifhp);
- cmovge(reg_maskj, reg_zero);
-
- /*address offset for tile in src*/
- mov(reg_src_offset, reg_ydim);
- sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
- imul(reg_src_offset, reg_src_offset, jcp.iw);
-
- mov(reg_xdim, reg_ti);
- shl(reg_xdim, 2); // xdim = ti * tile_size
-
- add(reg_src_offset, reg_xdim);
- sub(reg_src_offset, jcp.l_pad);
- imul(reg_src_offset, reg_src_offset, simd_w * typesize);
- for (int i = 0; i < alpha; i++) {
- /* check if tile index is within physical spatial boundaries*/
- mov(reg_maski, 0xffff);
- cmp(reg_xdim, jcp.l_pad);
- cmovl(reg_maski, reg_zero);
- cmp(reg_xdim, ifwp);
- cmovge(reg_maski, reg_zero);
- and_(reg_maski, reg_maskj);
-
- Opmask kmask_src = Xbyak::Opmask(7);
- auto zmm_src = Xbyak::Zmm(31);
- kmovw(kmask_src, reg_maski_32);
- vpxord(zmm_src, zmm_src, zmm_src);
- vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
- vmovups(ptr[reg_I], zmm_src);
-
- add(reg_xdim, 1); //xdim = ti * tile_size + i
- add(reg_src_offset, simd_w * typesize);
- add(reg_I, simd_w * typesize);
- }
- add(reg_ydim, 1);
- }
- };
-
- auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
- vmovups(dst, c);
- vfmadd231ps(dst, a, b);
- };
-
- auto trans_I_3x3_4x4 = [=]() {
- //Use 24 registers
- mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
- mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
- for (int i = 0; i < alpha; i++) {
- for (int j = 0; j < alpha; j++) {
- size_t I_off = (j * alpha + i) * simd_w * typesize;
- vmovups(zmm_I(j), ptr[reg_I + I_off]);
- }
-
- fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
- fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
- fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
- fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
- fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
- fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
-
- fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
- fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
- fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
- fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
- fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
- fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
-
- for (int j = 0; j < alpha; j++) {
- vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
- zmm_T(j));
- }
-
- }
-
- for (int j = 0; j < alpha; j++) {
- for (int i = 0; i < alpha; i++) {
- vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
- }
-
- fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
- fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
- fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
- fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
- fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
- fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
-
- fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
- fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
- fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
- fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
- fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
- fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
-
- for (int i = 0; i < alpha; i++) {
- size_t dst_off = (j * alpha * jcp.ic_block
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
- * simd_w * typesize;
- vmovups(ptr[reg_dst + dst_off], zmm_I(i));
- }
- }
- };
-
- auto compute_transform_SDGtWo = [=]() {
- mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
- mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- xor_(reg_tile_count, reg_tile_count);
- Label loop_mb, loop_jtiles, loop_itiles, done;
- L(loop_mb);
- {
- L(loop_jtiles);
- {
- L(loop_itiles);
- {
- load_src();
-
- trans_I_3x3_4x4();
-
- add(reg_tile_count, 1);
- cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
- jge(done);
-
- add(reg_dst, simd_w * typesize);
- add(reg_ti, 1);
- cmp(reg_ti, jcp.itiles);
- jl(loop_itiles);
- }
- xor_(reg_ti, reg_ti);
- add(reg_tj, 1);
- cmp(reg_tj, jcp.jtiles);
- jl(loop_jtiles);
- }
- xor_(reg_tj, reg_tj);
- add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
- jmp(loop_mb);
- }
- L(done);
- };
-
- auto compute_transform = [=]() {
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- xor_(reg_ti, reg_ti);
- xor_(reg_tj, reg_tj);
-
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
- imul(reg_temp, reg_tile_count, simd_w * typesize);
- add(reg_dst, reg_temp);
-
- Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
- L(loop_jtiles);
-
- {
- L(loop_itiles);
- {
- load_src();
-
- trans_I_3x3_4x4();
-
- add(reg_tile_count, 1);
- cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
- jge(next_tile_block);
- add(reg_dst, simd_w * typesize);
- jmp(next_tile);
-
- L(next_tile_block);
- sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
- * simd_w * typesize);
- size_t tblk_off = alpha * alpha * jcp.ic_block
- * jcp.nb_tile_block_ur * jcp.tile_block_ur
- * simd_w * typesize;
- add(reg_dst, tblk_off);
- xor_(reg_tile_count, reg_tile_count);
-
- L(next_tile);
- add(reg_ti, 1);
- cmp(reg_ti, jcp.itiles);
- jl(loop_itiles);
- }
- xor_(reg_ti, reg_ti);
- add(reg_tj, 1);
- cmp(reg_tj, jcp.jtiles);
- jl(loop_jtiles);
- }
- };
-
- preamble();
- init_G();
- if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
- compute_transform_SDGtWo();
- else
- compute_transform();
- postamble();
-}
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
-diff_dst_transform_generate(bool with_bias) {
-
- constexpr int G_size = 8;
- auto zmm_G = [](int i) {
- return Xbyak::Zmm(31);
- };
-
- auto zmm_src = [=](int j, int i) {
- return Xbyak::Zmm(G_size + j * 4 + i);
- };
-
- auto zmm_bias = Xbyak::Zmm(31);
-
- auto load_src = [=]() {
- if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
- mov(reg_ydim, reg_tj);
- shl(reg_ydim, 2); //tj * tile_size(=4)
- for (int j = 0; j < tile_size; j++) {
- /* check if tile index is within physical spatial boundaries*/
- mov(reg_maskj, 0xffff);
- cmp(reg_ydim, jcp.oh);
- cmovge(reg_maskj, reg_zero);
-
- /*address offset for tile in src*/
- mov(reg_src_offset, reg_ydim);
- imul(reg_src_offset, reg_src_offset, jcp.ow);
-
- mov(reg_xdim, reg_ti);
- shl(reg_xdim, 2); // xdim = ti * tile_size
-
- add(reg_src_offset, reg_xdim);
- imul(reg_src_offset, reg_src_offset, simd_w * typesize);
- for (int i = 0; i < tile_size; i++) {
- /* check if tile index is within physical spatial boundaries*/
- mov(reg_maski, 0xffff);
- cmp(reg_xdim, jcp.ow);
- cmovge(reg_maski, reg_zero);
- and_(reg_maski, reg_maskj);
-
- Opmask kmask_src = Xbyak::Opmask(7);
- kmovw(kmask_src, reg_maski_32);
- vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
- vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
- if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
- ptr[reg_src + reg_src_offset]);
-
- add(reg_xdim, 1); //xdim = ti * tile_size + i
- add(reg_src_offset, simd_w * typesize);
- }
- add(reg_ydim, 1);
- }
- if(with_bias) vmovups(ptr[reg_bias], zmm_bias);
- };
-
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(G_size + 16 + i);
- };
-
- auto zmm_T = [=](int j, int i) {
- return Xbyak::Zmm(j * 4 + i);
- };
-
- auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
- if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
- vmovups(ptr[reg_dst + dst_off], a);
- else
- vmovntps(ptr[reg_dst + dst_off], a);
- };
-
- auto trans_W_3x3_4x4 = [=]() {
- mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
- for (int i = 0; i < tile_size; i++) {
- vbroadcastss(zmm_G(0), ptr[reg_G]);
- vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));
-
- vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
- vmovups(zmm_t(1), zmm_t(0));
- vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));
-
- vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
- vmovups(zmm_t(2), zmm_t(0));
- vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));
-
- vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
- vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));
-
- vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
- vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));
-
- vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
- vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));
-
- vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
- vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));
-
- vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
- vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
- vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
- vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
- vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
- vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
- vmovups(zmm_T(5, i), zmm_src(3, i));
- }
-
- for (int j = 0; j < alpha; j++) {
- vbroadcastss(zmm_G(0), ptr[reg_G]);
- vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));
-
- vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
- vmovups(zmm_t(1), zmm_t(0));
- vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));
-
- vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
- vmovups(zmm_t(2), zmm_t(0));
- vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));
-
- vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
- vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));
-
- vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
- vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));
-
- vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
- vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));
-
- vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
- vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));
-
- vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
- vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
- vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
- vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
- vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
- vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
- vmovups(zmm_t(3), zmm_T(j, 3));
-
- int alpha_offset = (jcp.oc / jcp.nb_oc)
- * (jcp.ntiles / jcp.tile_block) * typesize;
- int dst_off = j * alpha * alpha_offset;
- movps(reg_dst, dst_off, zmm_t(0));
- dst_off += alpha_offset;
- movps(reg_dst, dst_off, zmm_t(5));
- dst_off += alpha_offset;
- movps(reg_dst, dst_off, zmm_t(1));
- dst_off += alpha_offset;
- movps(reg_dst, dst_off, zmm_t(6));
- dst_off += alpha_offset;
- movps(reg_dst, dst_off, zmm_t(2));
- dst_off += alpha_offset;
- movps(reg_dst, dst_off, zmm_t(3));
- }
-
- };
- auto compute_transform_SDGtWo = [=]() {
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
-
- xor_(reg_zero, reg_zero);
- xor_(reg_oc_ur, reg_oc_ur);
- Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;
-
- L(loop_oc_ur);
- {
- mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
- mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
- xor_(reg_tile_count, reg_tile_count);
- L(loop_mb);
- {
- L(loop_jtiles);
- {
- L(loop_itiles);
- {
- load_src();
-
- trans_W_3x3_4x4();
-
- add(reg_tile_count, 1);
- cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
- jge(tiles_done);
-
- add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
- add(reg_ti, 1);
- cmp(reg_ti, jcp.itiles);
- jl(loop_itiles);
- }
- xor_(reg_ti, reg_ti);
- add(reg_tj, 1);
- cmp(reg_tj, jcp.jtiles);
- jl(loop_jtiles);
- }
- xor_(reg_tj, reg_tj);
- add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
- jmp(loop_mb);
- }
-
- L(tiles_done);
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- add(reg_dst, simd_w * typesize);
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
-
- if (with_bias) add(reg_bias, simd_w * typesize);
- add(reg_oc_ur, 1);
- cmp(reg_oc_ur, jcp.oc_reg_block);
- jl(loop_oc_ur);
- }
- };
-
- auto compute_transform = [=]() {
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
- if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
-
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
- imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
- add(reg_dst, reg_temp);
-
- xor_(reg_zero, reg_zero);
- xor_(reg_oc_ur, reg_oc_ur);
- Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;
-
- L(loop_oc_ur);
- {
- xor_(reg_ti, reg_ti);
- xor_(reg_tj, reg_tj);
-
- L(loop_jtiles);
- {
- L(loop_itiles);
- {
- load_src();
-
- trans_W_3x3_4x4();
-
- add(reg_tile_count, 1);
- cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
- jge(next_tile_block);
- add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
- jmp(next_tile);
-
- L(next_tile_block);
- sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
- * jcp.oc_reg_block * simd_w * typesize);
- int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc)
- * (jcp.ntiles/jcp.tile_block) * typesize;
- add(reg_dst, tblk_off);
- xor_(reg_tile_count, reg_tile_count);
-
- L(next_tile);
- add(reg_ti, 1);
- cmp(reg_ti, jcp.itiles);
- jl(loop_itiles);
- }
- xor_(reg_ti, reg_ti);
- add(reg_tj, 1);
- cmp(reg_tj, jcp.jtiles);
- jl(loop_jtiles);
- }
-
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
- mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
- imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
- add(reg_dst, reg_temp);
- add(reg_dst, simd_w * typesize);
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
-
- if (with_bias) add(reg_bias, simd_w * typesize);
- add(reg_oc_ur, 1);
- cmp(reg_oc_ur, jcp.oc_reg_block);
- jl(loop_oc_ur);
- }
- };
-
- preamble();
- if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
- compute_transform_SDGtWo();
- } else {
- compute_transform();
- }
- postamble();
-}
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
-diff_weights_transform_generate(bool first_tile) {
- int G_size = 4;
-
- auto zmm_G = [](int i) {
- return Xbyak::Zmm(i);
- };
-
- auto init_G = [=]() {
- mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
- for (int i = 0; i < G_size; i++)
- vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
- };
-
- auto zmm_src = [=](int i) {
- return Xbyak::Zmm(G_size + i);
- };
-
- auto load_src = [=](int i) {
- for (int j = 0; j < alpha; j++) {
- size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
- * jcp.ic_block * simd_w * simd_w * typesize;
- size_t src_off = (j * alpha + i) * alpha_offset;
- vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
- }
- };
-
- auto zmm_t = [=](int i) {
- return Xbyak::Zmm(G_size + 6 + i);
- };
-
- auto zmm_T = [=](int j, int i) {
- return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
- };
-
- auto zmm_dst = [=](int i) {
- return Xbyak::Zmm(G_size + i);
- };
-
- auto zmm_temp = Xbyak::Zmm(31);
-
- auto store_dst = [=](int j) {
- for (int i = 0; i < jcp.kw; i++) {
- size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;
-
- if (!first_tile) {
- vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
- vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
- }
- vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
- }
- };
-
- auto compute_transform = [=] () {
- mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
- mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
-
- xor_(reg_ic_simd, reg_ic_simd);
- Label loop_ic_simd;
- L(loop_ic_simd);
- {
- for (int i = 0; i < alpha; i++) {
- load_src(i);
-
- vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
- vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
- vmovups(zmm_t(2), zmm_src(5));
- vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
-
- vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
- vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
- vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
- vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
- vsubps(zmm_temp, zmm_src(3), zmm_src(4));
- vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
- vmovups(zmm_T(2, i), zmm_t(2));
- vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
- }
-
- for (int j = 0; j < jcp.kh; j++) {
- vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
- vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
- vmovups(zmm_t(2), zmm_T(j, 5));
- vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
-
- vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
- vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
- vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
- vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
- vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
- vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
- vmovups(zmm_dst(2), zmm_t(2));
- vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));
-
- store_dst(j);
- }
-
- add(reg_src, jcp.oc_reg_block * simd_w * typesize);
- add(reg_dst, simd_w * typesize);
- add(reg_ic_simd, 1);
- cmp(reg_ic_simd, simd_w);
- jl(loop_ic_simd);
- }
- };
- preamble();
- push(reg_EVEX_max_8b_offt);
- mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
- init_G();
- compute_transform();
- pop(reg_EVEX_max_8b_offt);
- postamble();
-}
-
-void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
- bool is_first_tile)
-{
- auto zmm_srcA = [=]() {
- return Xbyak::Zmm(0);
- };
-
- auto zmm_srcB = [=] (size_t N_ur){
- return Xbyak::Zmm(N_ur + 1);
- };
-
- auto broadcastB = [=](size_t K_ur) {
- for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
- size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
- * sizeof(float);
- vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
- }
- };
-
- auto load_srcA = [=] (size_t K_ur, int M_ur) {
- size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
- + M_ur * jcp.dimM_simd_block) * sizeof(float);
- vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
- };
-
- auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
- size_t idx = 1 // zmm_srcA
- + jcp.dimN_bcast_ur // zmm_srcB
- + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
- assert(idx < 32);
- return Xbyak::Zmm(idx);
- };
- auto prepare_accumm = [=](){
- for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
- for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
- Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
- vpxord(zmm, zmm, zmm);
- }
- }
- };
-
- auto store_dstC = [=](){
- /******** Write C back to memory *******/
- for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
- for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
- Zmm zmm = zmm_dstC(M_reg, N_ur);
- size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
- + M_reg * jcp.dimM_simd_block) * sizeof(float);
- if (!is_first_tile) {
- vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
- vaddps(zmm, zmm, Xbyak::Zmm(0));
- }
- vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
- }
- }
- };
-
- auto inner_loops = [=]() {
- Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;
-
- mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
- L(dimM_block_loop);
- { /************* OC_block (M) loop ***********/
- mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
- L(dimN_block_loop);
- { /*************** IC_block (N) loop *********/
-
- mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
- L(dimN_bcast_ur);
- {
- prepare_accumm();
-
- mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
- L(dimK_block_loop);
- {
- /************* nb_tile_ur(K) loop ********/
- for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {
-
- broadcastB(K_ur);
-
- for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
- load_srcA(K_ur, M_reg_ur);
- for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
- vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
- zmm_srcB(N_bcast));
- }
- }
- }
- add(reg_srcA, jcp.dimK_reg_block
- * jcp.dimM_reg_block * jcp.dimM_simd_block
- * sizeof(float));
- add(reg_srcB, jcp.dimK_reg_block
- * jcp.dimN_reg_block
- * sizeof(float));
- sub(reg_dimK_block_loop_cnt, 1);
- jnz(dimK_block_loop);
- }
-
- store_dstC();
-
- sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
- * jcp.dimM_reg_block * jcp.dimM_simd_block
- * sizeof(float));
- sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
- * jcp.dimN_reg_block
- * sizeof(float));
- add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
- add(reg_dstC, jcp.dimN_bcast_ur
- * jcp.dimM_reg_block * jcp.dimM_simd_block
- * sizeof(float));
- sub(reg_nb_dimN_bcast_ur, 1);
- jnz(dimN_bcast_ur);
- }
-
- sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
- add(reg_srcB, jcp.dimK_block
- * jcp.dimK_reg_block
- * jcp.dimN_reg_block * sizeof(float));
- sub(reg_dimN_block_loop_cnt, 1);
- jnz(dimN_block_loop);
- }
-
- sub(reg_srcB, jcp.dimN_block
- * jcp.dimK_block * jcp.dimK_reg_block
- * jcp.dimN_reg_block
- * sizeof(float));
- add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
- * jcp.dimM_reg_block * jcp.dimM_simd_block
- * sizeof(float));
- sub(reg_dimM_block_loop_cnt, 1);
- jnz(dimM_block_loop);
- }
- };
-
- /* Preamble */
- preamble();
-
- inner_loops();
-
- /* Postamble */
- postamble();
- ret();
-}
-
-namespace {
-
-void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
-/*M params*/
- jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
- / jcp.dimM_simd_block;
- jcp.oc_reg_block = jcp.dimM_reg_block;
- jcp.oc_block = jcp.dimM_block;
- jcp.nb_oc = jcp.dimM_nb_block;
- /*N params*/
- jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
- jcp.ic_block = jcp.dimN_block;
- jcp.nb_ic = jcp.dimN_nb_block;
-
- /*K params*/
- jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
- jcp.tile_block_ur = jcp.dimK_reg_block;
- jcp.nb_tile_block_ur = jcp.dimK_block;
- jcp.tile_block = jcp.dimK_nb_block;
-}
-
-status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
-
- size_t K_blk_ur, N_blk, M_blk;
- /* IS this strategy feasible? */
- auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
- size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
- size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
- size_t nthreads = mkldnn_get_max_threads();
- return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
- && (jcp.dimK / nthreads >= 1.0);
- };
-
- auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
- int max_block=1) {
- size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
- size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
- size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
- size_t nthreads = mkldnn_get_max_threads();
- bool load_balance=true;
- if (!(jcp.dimK % nthreads)) {
- load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
- }
- return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
- && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
- && load_balance
- && (M_L2_block < L2_cache_size);
- };
-
- auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
- int useless_arg=0) {
- return (dimK_ur >= 2) && (dimK_ur <= 8);
- };
-
- auto blocking_ok = [&](){
- size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
- * K_blk_ur * sizeof(float);
- size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
- * K_blk_ur * sizeof(float);
- size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
- * N_blk * jcp.dimN_reg_block * sizeof(float);
- size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
-        /* TODO: derive these bounds from the combined L2+L3 cache size instead of hard-coded factors */
- return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
- };
-
- if (test_MV_large_enough(jcp)) {
- if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
- jcp.dimM_reg_block = 2;
- } else {
- jcp.dimM_reg_block = 1;
- }
- jcp.dimM_simd_block = jcp.oc_simd_block;
- jcp.dimN_reg_block = jcp.ic_simd_block;
- jcp.dimN_bcast_ur = 8;
- /*dimK_block and dimK_ur*/
- size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);
-
- jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
- jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
- for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
- if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
- for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
- if (!(jcp.dimN_block % N_blk)) {
- for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
- if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
- jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
- if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
- jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
- jcp.dimN_block = N_blk;
- jcp.dimM_block = M_blk;
- jcp.sched_policy = WSCHED_WEI_SDGtWo;
- set_jcp_WEI_params(jcp);
- jcp.nthr = nstl::min(mkldnn_get_max_threads(),
- jcp.tile_block);
- return status::success;
- }
- }
- }
- }
- }
- }
- }
- return status::unimplemented;
-}
-
-status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
- if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
- jcp.dimM_reg_block = 2;
- } else {
- jcp.dimM_reg_block = 1;
- }
- jcp.dimN_bcast_ur = 8;
- jcp.dimN_reg_block = jcp.ic_simd_block;
- jcp.dimM_simd_block = jcp.oc_simd_block;
- jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
- jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
- float C1 = 0.0, C2 = 0.0;
- float C1_max = 0.5, C2_max = 1.4;
- int N_blk, M_blk, K_blk_ur;
-
- auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
- int useless_arg=0) {
- return (dimK_ur >= 2) && (dimK_ur <= 8);
- };
-
- auto blocking_ok = [&]() -> bool {
- size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
- size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
- bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
- && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);
-
- size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
- size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
- size_t nb_K_blk = jcp.dimK / K_blk_ur;
- size_t nthreads = mkldnn_get_max_threads();
- bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
- if (!(nb_K_blk % nthreads)) {
- load_balance = load_balance && (nb_K_blk % nthreads == 0);
- }
-
- size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);
-
- size_t L2_block = V_L2_block;
-        /* TODO: derive C2_max from the combined L2+L3 cache size instead of a hard-coded factor */
- bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
- return L1_cond && load_balance && L2_cond;
- };
-
- for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
- if (jcp.dimK % K_blk_ur == 0) {
- for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
- if (jcp.dimN_block % N_blk == 0) {
- for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
- if (jcp.dimM_block % M_blk == 0) {
- if (blocking_ok()) {
- jcp.dimN_block = N_blk;
- jcp.dimM_block = M_blk;
- jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
- jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
- jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
- set_jcp_WEI_params(jcp);
- return status::success;
- }
- }
- }
- }
- }
- }
- }
- jcp.dimK_reg_block = 1;
- jcp.dimK_block = 1;
- jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
- set_jcp_WEI_params(jcp);
- return status::success;
-}
-} // namespace
-status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
- jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
- const memory_desc_wrapper &diff_weights_d) {
- if (!mayiuse(avx512_core))
- return status::unimplemented;
- else
- jcp.ver = ver_avx512_core;
-
- jcp.nthr = mkldnn_get_max_threads();
-
- jcp.prop_kind = cd.prop_kind;
- const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
- jcp.mb = src_d.dims()[0];
- jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = diff_dst_d.dims()[2];
- jcp.ow = diff_dst_d.dims()[3];
- jcp.kh = diff_weights_d.dims()[with_groups + 2];
- jcp.kw = diff_weights_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.r_pad = nstl::max(
- 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
- jcp.b_pad = nstl::max(
- 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
- jcp.ohp = jcp.oh;
- jcp.owp = jcp.ow;
- jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- bool ok_to_pad_channels = jcp.ngroups == 1;
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- }
-
- // Winograd specific initialization
- jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
- jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
- jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
-
- // Winograd kernel works only for 3x3 convolution with stride 1
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
-
- if (jcp.ngroups != 1)
- return status::unimplemented;
- if ((jcp.kh != 3) || (jcp.kw != 3))
- return status::unimplemented;
- if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
- return status::unimplemented;
- if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
- return status::unimplemented;
- if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
- return status::unimplemented;
-
- format_tag_t dat_tag = nChw16c;
- format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
-
- if (jcp.src_tag != dat_tag) return status::unimplemented;
- if (jcp.wei_tag != wei_tag) return status::unimplemented;
- if (jcp.dst_tag != dat_tag) return status::unimplemented;
-
- bool layout_consistency = true
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= diff_dst_d.padded_dims()[1]
- && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
- && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
- if (!layout_consistency) return status::unimplemented;
-
- /******************Kernel blocking Parameters ***********/
- jcp.ic_simd_block = simd_w;
- jcp.oc_simd_block = simd_w;
-
- jcp.dimK = jcp.ntiles;
- jcp.dimN = jcp.ic;
- jcp.dimM = jcp.oc;
- jcp.dimM_simd_block = jcp.oc_simd_block;
- jcp.dimN_reg_block = jcp.ic_simd_block;
- jcp.sched_policy = WSCHED_INVALID;
- status_t res = set_wsched_WEI_SDGtWo(jcp);
- if (res == status::unimplemented) {
- res = set_wsched_WEI_S_D_Giot_W(jcp);
- assert(res == status::success);
- }
- return res;
-}
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp
deleted file mode 100644
index 025a554d92..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp
+++ /dev/null
@@ -1,291 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
-#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
-
-#include "c_types_map.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-
-#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
- : public jit_generator {
- _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(
- jit_conv_winograd_conf_t ajcp)
- : jcp(ajcp) {
- {
- this->weights_transform_data_ker_generate();
- weights_transform_data_ker
- = (decltype(weights_transform_data_ker)) this->getCode();
- }
- {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->input_transform_data_ker_generate();
- input_transform_data_ker = (decltype(input_transform_data_ker))addr;
- }
- {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->output_transform_data_ker_generate();
- output_transform_data_ker
- = (decltype(output_transform_data_ker))addr;
- }
- {
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->gemm_loop_generate();
- gemm_loop_ker = (decltype(gemm_loop_ker))addr;
- }
- }
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel)
-
- static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d);
-
- static status_t init_conf_kernel(
- jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
-
- jit_conv_winograd_conf_t jcp;
- void (*gemm_loop_ker)(float *, const float *, const float *, const int);
- void (*input_transform_data_ker)(jit_wino_transform_call_s *);
- void (*output_transform_data_ker)(jit_wino_transform_call_s *);
- void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
-
-protected:
- using reg64_t = const Xbyak::Reg64;
- using reg32_t = const Xbyak::Reg32;
- enum { typesize = sizeof(float) };
-
- void gemm_loop_generate();
- void input_transform_data_ker_generate();
- void output_transform_data_ker_generate();
- void weights_transform_data_ker_generate();
-
- /* registers used for GEMM */
- reg64_t reg_dstC = abi_param1;
- reg64_t reg_srcA = abi_param2;
- reg64_t reg_srcB = abi_param3;
- reg64_t reg_is_beta_zero = abi_param4;
-
- reg64_t reg_dimM_block_loop_cnt = r10;
- reg64_t reg_dimK_block_loop_cnt = r11;
-
- /* registers used for transforms*/
- reg64_t param = abi_param1;
-
- /* registers used for output_transform_data_ker */
- reg64_t oreg_temp = abi_not_param1;
- reg64_t oreg_Ow = r9;
- reg64_t oreg_src = r11;
- reg64_t oreg_tile_block = r12;
- reg64_t oreg_tile_block_ur = r13;
- reg64_t oreg_nb_tile_block_ur = r14;
- reg64_t oreg_O = r8;
- reg64_t oreg_T = r10;
- reg64_t oreg_dst = r11;
- reg64_t oreg_ydim = r14;
- reg64_t oreg_xdim = r15;
- reg64_t oreg_out_j = r12;
- reg64_t oreg_bias = rbx;
- reg64_t imm_addr64 = rax;
-
- /* registers used for input_transform_data_ker */
- reg64_t ireg_temp = abi_not_param1;
- reg64_t ireg_jtiles = rax;
- reg64_t ireg_itiles = rbx;
- reg64_t ireg_I = r8;
- reg64_t ireg_src = r13;
- reg64_t ireg_ydim = r14;
- reg64_t ireg_xdim = r15;
- reg64_t ireg_inp_j = r12;
- reg64_t ireg_inp_i = rdx;
- reg64_t ireg_mask_j = r11;
- reg64_t ireg_mask = rsi;
- reg32_t ireg_mask_32 = esi;
- reg64_t ireg_zero = r9;
- reg64_t ireg_Iw = r9;
- reg64_t ireg_T = r10;
- reg64_t ireg_tile_block = r12;
- reg64_t ireg_tile_block_ur = r13;
- reg64_t ireg_nb_tile_block_ur = r14;
- reg64_t ireg_output = r15;
-
- /* registers used for wei transform */
- reg64_t wreg_temp = abi_not_param1;
- reg64_t wreg_F = r8;
- reg64_t wreg_src = r9;
- reg64_t wreg_MT = r15;
- reg64_t wreg_M = r14;
- reg64_t wreg_dst = r10;
- reg64_t wreg_dst_aux = r9;
- reg64_t wreg_dst_idx = r8;
- reg64_t wreg_Fw = r11;
- reg64_t wreg_T = r12;
- reg64_t wreg_cnt_j = rdx;
- reg64_t wreg_F_aux = r14;
- reg64_t wreg_Fw_aux = r15;
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel
- : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
- using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
- _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
-
- static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_t &src_md,
- memory_desc_t &weights_md, const memory_desc_t &dst_md,
- const primitive_attr_t &attr);
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel
- : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
- using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
- _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d);
-};
-
-struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel
- : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
-
- jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
- jit_conv_winograd_conf_t ajcp)
- : jcp(ajcp)
- {
- //******************* First iter kernel ********************//
- this->gemm_loop_generate(true);
- gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode();
-
- align();
- const Xbyak::uint8 *addr = getCurr();
- this->src_transform_generate();
- src_transform = (decltype(src_transform))addr;
-
- if (jcp.with_bias) {
- align();
- addr = getCurr();
- this->diff_dst_transform_generate(true);
- diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
- }
-
- align();
- addr = getCurr();
- this->diff_dst_transform_generate(false);
- diff_dst_transform = (decltype(diff_dst_transform))addr;
-
- if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
- align();
- addr = getCurr();
- this->gemm_loop_generate(false);
- gemm_loop_ker = (decltype(gemm_loop_ker))addr;
- }
-
- align();
- addr = getCurr();
- this->diff_weights_transform_generate(true);
- diff_weights_transform = (decltype(diff_weights_transform))addr;
-
- if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
- align();
- addr = getCurr();
- this->diff_weights_transform_generate(false);
- diff_weights_transform_accum =
- (decltype(diff_weights_transform_accum))addr;
- };
- }
-
- static status_t init_conf(jit_conv_winograd_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_dst_d,
- const memory_desc_wrapper &diff_weights_d);
-
- jit_conv_winograd_conf_t jcp;
- void (*gemm_loop_ker)(float *, const float *, const float *);
- void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
- void (*src_transform)(jit_wino_transform_call_s *);
- void (*diff_dst_transform)(jit_wino_transform_call_s *);
- void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
- void (*diff_weights_transform)(jit_wino_transform_call_s *);
- void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using reg32_t = const Xbyak::Reg32;
- enum { typesize = sizeof(float) };
-
- void src_transform_generate();
- void diff_dst_transform_generate(bool with_bias);
- void diff_weights_transform_generate(bool first_tile);
-
- /*registers common to transforms*/
- reg64_t reg_transp = abi_param1;
- reg64_t reg_ti = rbx;
- reg64_t reg_tj = abi_not_param1;
- reg64_t reg_src = r8;
- reg64_t reg_dst = r9;
- reg64_t reg_G = rsi; /*TODO: check if this is ok*/
- reg64_t reg_temp = rsi;
-
- /*registers common to src/diff_dst transform*/
- reg64_t reg_I = r10;
- reg64_t reg_ydim = r11;
- reg64_t reg_xdim = r12;
- reg64_t reg_src_offset = r13;
- reg64_t reg_zero = r14;
- reg64_t reg_tile_count = r15;
- reg64_t reg_maski = rsi;
- reg32_t reg_maski_32 = esi;
- reg64_t reg_maskj = rdx;
-
- reg64_t reg_T = rax;
- reg64_t reg_oc_ur = rax;
- reg64_t reg_ic_simd = r14;
- reg64_t reg_bias = r10;
-
- void gemm_loop_generate(bool is_first_tile);
-
- reg64_t reg_dstC = abi_param1;
- reg64_t reg_srcA = abi_param2;
- reg64_t reg_srcB = abi_param3;
-
- reg64_t reg_dimM_block_loop_cnt = r9;
- reg64_t reg_dimN_block_loop_cnt = r10;
- reg64_t reg_nb_dimN_bcast_ur = r11;
- reg64_t reg_dimK_block_loop_cnt = r12;
-};
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
deleted file mode 100644
index 002010ffa2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
+++ /dev/null
@@ -1,1284 +0,0 @@
-/*******************************************************************************
- * Copyright 2018 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp"
-#include "jit_generator.hpp"
-
-#include <string.h>
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-namespace {
- // Below scales are applied to source and weights data accordingly
- // because this winograd implementation
- // transforms source which may increase values up to 4x
- // and transforms weights which may increase values up to 9/4x
- const float adj_src_scale = 1.f / 4.f;
- const float adj_wei_scale = 4.f / 9.f;
- // Winograd transforms need ic and oc to be multiples of 16
- const int load_block = 16;
-}
-
-/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
-struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)
-
- jit_conv_conf_2x3_wino_t jcp;
- const primitive_attr_t &attr_;
-
- struct call_params_t {
- const void *src;
- const void *wino_src;
- const void *v_y_masks;
- const void *v_x_masks;
- };
- void (*ker_)(const call_params_t *);
-
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
- }
- void generate();
-
- int reg_inp_ind(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return (31 - i);
- }
-
- Xmm vreg_inp(int i) {
- return Xmm(reg_inp_ind(i));
- }
-
- Zmm zmm_inp(int i) {
- return Zmm(reg_inp_ind(i));
- }
-
- Xmm vreg_tmp(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return Xmm(15 - i);
- }
- Xmm vreg_out(int i) {
- assert(i < jcp.alpha * jcp.alpha);
- return Xmm(31 - i);
- }
-
- Opmask y_mask = Opmask(1);
- Opmask r_mask = Opmask(2);
- Opmask x_mask(int id) {
- assert(id < 4);
- return Opmask(3 + id);
- }
-
- Reg64 reg_ptr_src = r14;
- Reg64 reg_ptr_dst = r13;
-
- Reg64 reg_ptr_v_y_masks = r12;
- Reg64 reg_ptr_v_x_masks = r11;
-
- Reg64 reg_aux_ptr_src = r10;
- Reg64 reg_aux_ptr_dst = r9;
-
- Reg64 reg_ic_block = r8;
-
- int unsign_val_in_wino_domain;
-
- Reg64 reg_scratch_src_alpha = rdx;
- Xmm xmm_src_alpha = Xmm(0);
- Zmm zmm_src_alpha = Zmm(0);
-
- Reg64 reg_shift = rax;
- Xmm xmm_shift = Xmm(1);
- Xmm xmm_zero = Xmm(0);
-
- Reg64 reg_maskx = rbx;
- Reg64 reg_masky = rsi;
- Reg64 reg_nomask = reg_maskx;
-};
-
-void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
- Label ic_block_label;
- Label end_label;
- Label mask_label;
- Label nomask_label;
-
- auto load_src = [=](bool mask) {
- for (int y = 0; y < jcp.alpha; y++) {
- if (mask)
- kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
- for (int x = 0; x < jcp.alpha; x++) {
- Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
- Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
- int inp_offset = sizeof(uint8_t)
- * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
- + (-jcp.l_pad + x) * jcp.ic);
- if (mask) {
- kandw(r_mask, y_mask, x_mask(x));
- vmovdqu8(vreg_i | r_mask | T_z,
- EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
- } else {
- vmovdqu8(vreg_i,
- EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
- }
- vpmovzxbd(zmm_i, vreg_i); // to int32
- vcvtdq2ps(zmm_i, zmm_i); // to fp32
- vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
- vcvtps2dq(zmm_i, zmm_i); // to int32
- vpmovusdb(vreg_i, zmm_i); // to u8
- }
- }
- };
-
- preamble();
-
-# define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_ptr_src, src);
- READ_PARAM(reg_ptr_dst, wino_src);
- READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
- READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
-# undef READ_PARAM
-
- mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
- mov(reg_masky, ptr[reg_ptr_v_y_masks]);
- test(reg_maskx, reg_maskx);
- jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
- test(reg_masky, reg_masky);
- jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
- and_(reg_maskx, reg_masky);
- mov(reg_nomask, reg_maskx);
- not_(reg_nomask); // zero if x and y masks are all 1's
-
- xor_(reg_shift, reg_shift);
- mov(reg_shift.cvt8(), (int8_t)-128);
-
- mov(reg_aux_ptr_src, reg_ptr_src);
- mov(reg_aux_ptr_dst, reg_ptr_dst);
-
- for (int i = 0; i < jcp.alpha; i++) {
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
- }
-
- mov(reg_scratch_src_alpha, float2int(adj_src_scale));
-
- mov(reg_ic_block, jcp.ic / load_block);
- L(ic_block_label);
- {
- vmovq(xmm_src_alpha, reg_scratch_src_alpha);
- vbroadcastss(zmm_src_alpha, xmm_src_alpha);
-
- test(reg_nomask, reg_nomask);
- jz(nomask_label, T_NEAR);
- load_src(true);
- jmp(mask_label, T_NEAR);
- L(nomask_label);
- load_src(false);
- L(mask_label);
-
- for(int y = 0; y < 4; y++) {
- vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
- vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
- vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1));
- vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3));
- }
- for(int x = 0;x < 4; x++) {
- vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2));
- vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2));
- vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1));
- vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
- }
-
- vmovd(xmm_shift, reg_shift.cvt32());
- vpxor(xmm_zero, xmm_zero, xmm_zero);
- vpshufb(xmm_shift, xmm_shift, xmm_zero);
-
- for (int i = 0; i < 16; i++) {
- int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
- if (i != unsign_val_in_wino_domain)
- vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
- vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
- }
-
- add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
- add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
- }
- dec(reg_ic_block);
- jnz(ic_block_label, T_NEAR);
-
- L(end_label);
- postamble();
-}
-
-/// DST TRANSFORMS /////////////////////////////////////////////////////////////
-struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)
-
- jit_conv_conf_2x3_wino_t jcp;
- const primitive_attr_t &attr_;
-
- struct call_params_t {
- const void *wino_dst;
- const void *dst;
- const void *v_y_masks;
- const void *v_x_masks;
-
- const void *bias;
- const void *scales;
- };
- void (*ker_)(const call_params_t *);
-
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr) {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
- }
-
- void generate();
- bool maybe_relu(int position);
-
- Zmm vreg_inp(int i) { // 16
- assert(i < jcp.alpha * jcp.alpha);
- return Zmm(31 - i);
- }
- Zmm vreg_stg(int id) { // 8
- const int id_reg_stg = jcp.alpha * jcp.alpha + id;
- assert(id < 8);
- return Zmm(31 - id_reg_stg);
- }
- Zmm vreg_out(int id) { // 4
- const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
- assert(id < 4);
- return Zmm(31 - id_reg_out);
- }
- Xmm xmm_out(int id) { // 4
- const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
- assert(id < 4);
- return Xmm(31 - id_reg_out);
- }
- Zmm vreg_tmp(int id) { // 2
- const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
- assert(id < 2);
- return Zmm(31 - id_reg_tmp);
- }
-
- Zmm vreg_zero = Zmm(0);
- Zmm vreg_bias = Zmm(1);
- Zmm vreg_prev_dst = Zmm(2);
- Zmm zmm_bias_alpha = Zmm(2);
- Xmm xmm_bias_alpha = Xmm(2);
-
- Opmask y_mask = Opmask(1);
- Opmask r_mask = Opmask(2);
- Opmask x_mask(int id) {
- assert(id < 4);
- return Opmask(3 + id);
- }
-
- Reg64 reg_scratch_bias_alpha = r15;
-
- Reg64 reg_ptr_src = r14;
- Reg64 reg_ptr_dst = r13;
-
- Reg64 reg_ptr_v_y_masks = r12;
- Reg64 reg_ptr_v_x_masks = r11;
-
- Reg64 reg_aux_ptr_src = r10;
- Reg64 reg_aux_ptr_dst = r9;
-
- Reg64 reg_oc_block = r8;
-
- Reg64 reg_ptr_bias = rbx;
- Reg64 reg_ptr_scales = abi_not_param1;
- Reg64 reg_ptr_sum_scale = rdx;
-};
-
-bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* relu before sum */
- return false
- || p.contain(eltwise, 0)
- || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
- } else if (position == 1) {
- /* relu after sum */
- const int sum_idx = p.contain(sum, 0)
- ? 0 : (p.contain(sum, 1) ? 1 : -1);
- if (sum_idx == -1)
- return false;
-
- return false
- || p.contain(eltwise, sum_idx + 1)
- || jcp.dst_dt == data_type::u8;
- }
-
- return false;
-}
-
-void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
- Label oc_block_label;
-
- auto loop_body = [=]() {
- const auto &p = attr_.post_ops_;
- const int sum_idx = p.find(primitive_kind::sum);
- const float *p_sum_scale = (sum_idx != -1)
- ? &p.entry_[sum_idx].sum.scale
- : nullptr;
- if (p_sum_scale && *p_sum_scale != 1.f)
- mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
-
- for(int i = 0; i < 16; i++) {
- int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
- vmovups(vreg_inp(i),
- EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
- }
- for(int y = 0; y < jcp.alpha; y++) {
- vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1));
- vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2));
-
- vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2));
- vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3));
- }
- for(int x = 0; x < jcp.m; x++) {
- vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1));
- vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2));
-
- vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2));
- vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3));
- }
-
-
- if (jcp.with_bias) {
- vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
- vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
-
- auto bias_addr = ptr [ reg_ptr_bias ];
- switch (jcp.bia_dt) {
- case data_type::f32:
- case data_type::s32: vmovups(vreg_bias, bias_addr); break;
- case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
- case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
- default: assert(!"unsupported dst data type");
- }
- if (jcp.bia_dt != data_type::f32)
- vcvtdq2ps(vreg_bias, vreg_bias);
- vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
- }
- for(int y = 0; y < jcp.m; y++) {
- kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]);
- for(int x = 0; x < jcp.m; x++) {
- kandw(r_mask, y_mask, x_mask(x));
-
- int i = y * jcp.m + x;
- int offset = jcp.typesize_out *
- (y * jcp.ow * jcp.oc + x * jcp.oc);
- Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
-
- Zmm zmm = vreg_out(i);
- Xmm xmm = xmm_out(i);
- vcvtdq2ps(zmm, zmm);
- if (jcp.with_bias)
- vaddps(zmm, zmm, vreg_bias);
- vmulps(zmm, zmm, ptr [reg_ptr_scales]);
- if (maybe_relu(0))
- vmaxps(zmm, vreg_zero, zmm);
- if (p_sum_scale) { // post_op: sum
- vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32:
- vmovups(vreg_prev_dst | r_mask, addr); break;
- case data_type::s8:
- vpmovsxbd(vreg_prev_dst | r_mask, addr); break;
- case data_type::u8:
- vpmovzxbd(vreg_prev_dst | r_mask, addr); break;
- default: assert(!"unknown dst_dt");
- }
- if (jcp.dst_dt != data_type::f32)
- vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
- if (*p_sum_scale == 1.f)
- vaddps(zmm, vreg_prev_dst);
- else
- vfmadd231ps(zmm, vreg_prev_dst,
- zword_b[reg_ptr_sum_scale]);
- }
- if (maybe_relu(1))
- vmaxps(zmm, vreg_zero, zmm);
- if (jcp.dst_dt != data_type::f32)
- vcvtps2dq(zmm, zmm);
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32:
- vmovups(addr, zmm | r_mask); break;
- case data_type::s8:
- vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
- case data_type::u8:
- vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
- default: assert(!"unknown dst_dt");
- }
- }
- }
- };
-
- preamble();
-
-# define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_ptr_src, wino_dst);
- READ_PARAM(reg_ptr_dst, dst);
- READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
- READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
- READ_PARAM(reg_ptr_bias, bias);
- READ_PARAM(reg_ptr_scales, scales);
-# undef READ_PARAM
-
- if (jcp.with_bias)
- mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
-
- mov(reg_aux_ptr_src, reg_ptr_src);
- mov(reg_aux_ptr_dst, reg_ptr_dst);
-
- vpxord(vreg_zero, vreg_zero, vreg_zero);
-
- for (int i = 0; i < jcp.m; i++)
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
-
- int oc_blocks = jcp.oc / load_block;
- mov(reg_oc_block, oc_blocks);
- L(oc_block_label); {
- loop_body();
- add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
- add(reg_aux_ptr_dst, jcp.typesize_out * load_block);
-
- add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
- add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block);
- }
- dec(reg_oc_block);
- jnz(oc_block_label, T_NEAR);
-
- postamble();
-
-}
-
-/// GEMM kernel ////////////////////////////////////////////////////////////////
-struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
- jit_conv_conf_2x3_wino_t jcp;
- const primitive_attr_t &attr_;
-
- struct call_params_t {
- const void *src;
- const void *dst;
- const void *wei;
- const void *dst_b;
- };
- void (*ker_)(const call_params_t *);
-
- void generate();
- static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
- const primitive_attr_t &attr);
-
- jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
- jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr)
- {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
- }
-
- static status_t init_conf(
- jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &weights_md,
- memory_desc_t &dst_md, memory_desc_t &bias_md,
- const primitive_attr_t &attr);
-
- Zmm vreg_out(int n, int m) {
- const int id_reg_out = n * jcp.m_block + m;
- assert(id_reg_out < jcp.n2_block * jcp.m_block);
- return Zmm(31 - id_reg_out);
- }
- Zmm vreg_wei(int i) {
- assert(31 - jcp.n2_block * jcp.m_block - i
- > (jcp.ver == ver_vnni ? 0 : 2));
- return Zmm(31 - jcp.n2_block * jcp.m_block - i);
- }
-
- Zmm vreg_src = Zmm(0);
- Zmm vreg_one = Zmm(1);
- Zmm vreg_tmp = Zmm(2);
-
- Reg64 reg_ptr_src = r15;
-
- Reg64 reg_aux_dst_b = r13;
- Reg64 reg_aux_dst = r12;
- Reg64 reg_aux_dst2 = r11;
- Reg64 reg_aux_wei = r10;
- Reg64 reg_aux_wei2 = r9;
- Reg64 reg_aux_src = r8;
- Reg64 reg_aux_src2 = rax;
- Reg64 reg_mb = rbx;
- Reg64 reg_nnb = abi_not_param1;
- Reg64 reg_scratch = rdx;
- Reg64 reg_K = rsi;
-};
-
-bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
- jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
- using namespace primitive_kind;
- const auto &p = attr.post_ops_;
-
- auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
-
- switch (p.len_) {
- case 0: return true;
- case 1: return is_relu(0) || p.contain(sum, 0);
- case 2: return (p.contain(sum, 0) && is_relu(1)) ||
- (p.contain(sum, 1) && is_relu(0));
- case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
- default: return false;
- }
-
- return false;
-}
-
-void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
- Label nnb_loop_label, K_loop_label, mb_loop_label;
-
- auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else {
- vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
- vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
- vpaddd(vreg_acc, vreg_acc, vreg_tmp);
- }
- };
-
- preamble();
-# define READ_PARAM(reg, field) \
- mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
- READ_PARAM(reg_ptr_src, src);
- READ_PARAM(reg_aux_dst, dst);
- READ_PARAM(reg_aux_wei, wei);
- READ_PARAM(reg_aux_dst_b, dst_b);
-# undef READ_PARAM
-
- if (jcp.ver != ver_vnni) {
- xor_(reg_scratch, reg_scratch);
- Reg16 _t = reg_scratch.cvt16();
- mov(_t, 0x1);
- vpbroadcastw(vreg_one, _t);
- }
-
- if (!jcp.small_mb) {
- mov(reg_nnb, jcp.n_chunks);
- L(nnb_loop_label);
- }
- mov(reg_aux_dst2, reg_aux_dst);
- mov(reg_aux_src, reg_ptr_src);
- mov(reg_mb, jcp.M / jcp.m_block);
- L(mb_loop_label);
- {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- for (int m = 0; m < jcp.m_block; m++) {
- int offset = jcp.typesize_acc * nb2 * jcp.n_block;
- vmovups(vreg_out(nb2, m),
- EVEX_compress_addr(reg_aux_dst_b, offset));
- }
- }
- mov(reg_aux_src2, reg_aux_src);
- mov(reg_aux_wei2, reg_aux_wei);
- mov(reg_K, jcp.k_chunks);
- L(K_loop_label);
- {
- for (int k = 0; k < jcp.k2_block; k += 4) {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- int wei_offset
- = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
- vmovups(vreg_wei(nb2),
- EVEX_compress_addr(reg_aux_wei2, wei_offset));
- }
- for (int m = 0; m < jcp.m_block; m++) {
- int inp_offset = jcp.typesize_in * m * jcp.K;
- vpbroadcastd(vreg_src,
- EVEX_compress_addr(reg_aux_src2, inp_offset));
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
- compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
- }
- add(reg_aux_src2, jcp.typesize_in * 4);
- add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
- }
- }
- dec(reg_K);
- jnz(K_loop_label, T_NEAR);
-
- for (int m = 0; m < jcp.m_block; m++) {
- for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
- int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
- vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
- vreg_out(nb2, m));
- }
- }
- add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
- add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
- }
- dec(reg_mb);
- jnz(mb_loop_label, T_NEAR);
-
- if (!jcp.small_mb) {
- add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
- add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
- add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
-
- dec(reg_nnb);
- jnz(nnb_loop_label, T_NEAR);
- }
-
- postamble();
-}
-namespace {
-bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
- if (jcp.ver == ver_vnni) {
- return (jcp.mb <= mkldnn_get_max_threads()
- && (jcp.mb > 4
- && jcp.ic > 64
- && !(jcp.oc > 128 && jcp.ih < 14)))
- || jcp.mb > mkldnn_get_max_threads();
- }
- return true;
-}
-}
-
-status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
-::init_conf(jit_conv_conf_2x3_wino_t &jcp,
- const convolution_desc_t &cd, memory_desc_t &src_md,
- memory_desc_t &wei_md, memory_desc_t &dst_md,
- memory_desc_t &bias_md, const primitive_attr_t &attr) {
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper wei_d(&wei_md);
- const memory_desc_wrapper dst_d(&dst_md);
- const memory_desc_wrapper bias_d(&bias_md);
-
- const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
-
- jcp.nthr = mkldnn_get_max_threads();
-
- jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
- jcp.kh = wei_d.dims()[with_groups + 2];
- jcp.kw = wei_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.b_pad = cd.padding[1][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.r_pad = cd.padding[1][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- jcp.ver = ver_avx512_core;
- if (!(mayiuse(avx512_core) &&
- src_d.data_type() == data_type::u8
- && wei_d.data_type() == data_type::s8
- && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
- data_type::s8, data_type::u8)))
- return status::unimplemented;
- if (mayiuse(avx512_core_vnni))
- jcp.ver = ver_vnni;
-
- if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
- is_winograd_faster_than_direct(jcp)))
- return status::unimplemented;
-
- // block sizes needed for GEMM kernel
- jcp.ic_block = 4;
- jcp.oc_block = 16;
-
- bool ok = true
- && jcp.ngroups == 1
- && jcp.oc % load_block == 0 && jcp.ic % load_block == 0
- && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
- && everyone_is(3, jcp.kh, jcp.kw)
- && everyone_is(1, jcp.stride_h, jcp.stride_w)
- && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
- && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
- && one_of(jcp.t_pad, 0, 1)
- && one_of(jcp.l_pad, 0, 1);
- if (!ok) return status::unimplemented;
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
- jcp.dst_dt = cd.dst_desc.data_type;
-
- jcp.typesize_in = types::data_type_size(src_d.data_type());
- jcp.typesize_out = types::data_type_size(dst_d.data_type());
- jcp.typesize_acc = sizeof(int32_t);
- jcp.typesize_bia = jcp.with_bias
- ? types::data_type_size(bias_d.data_type())
- : 0;
-
- jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- jcp.m = 2;
- jcp.r = 3;
- jcp.alpha = jcp.m + jcp.r - 1;
-
- int aa = jcp.alpha * jcp.alpha;
- int L1_cap = get_cache_size(1, true);
- int L2_cap = get_cache_size(2, true);
- // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
- int free_regs = jcp.ver == ver_vnni ? 31 : 29;
-
- auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
- float thr_eff;
- float Z = (float)jcp.ic + jcp.oc;
- float Y = (float)jcp.ic * jcp.oc;
- if (small_mb == 0) { // outer par
- int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
- thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
- } else { // inner par
- int tranw = iy * ix / jcp.alpha;
- int gemmw = aa * (jcp.nb_oc / n2_b);
- int tranw_r = rnd_up(tranw, jcp.nthr);
- int gemmw_r = rnd_up(gemmw, jcp.nthr);
- thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
- }
- return thr_eff;
- };
-
- auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
- float mem_eff, req_mem;
- int M = ix * iy / jcp.alpha;
- if (small_mb == 0) { // outer parallelization strategy
- // memory for wino transforms (other memory has poor reuse)
- req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
- mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
- } else { // inner parallelization strategy
- // memory used during gemm
- int N = jcp.oc_block * n2_b;
- req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
- mem_eff = nstl::min(1.f, L2_cap / req_mem);
- // memory used during wino transforms
- int M_per_thr = div_up(M, jcp.nthr);
- req_mem = (float)aa * M_per_thr
- * (jcp.ic + jcp.typesize_acc * jcp.oc);
- if (req_mem > L2_cap)
- mem_eff = 0.1f;
- }
- return mem_eff;
- };
-
- auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
- float mem_eff, float reg_eff) {
- // these coefficients are chosen empirically
- float mem_fac = 0.1f, reg_fac = 0.2f;
- // normalized overhead relative to memory and register components
- float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
- // thread and work components affect all others
- tot_eff *= thr_eff * work_eff;
- return tot_eff;
- };
-
- auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff,
- int &m_block, int &n2_block, float &tot_eff) {
- int M = (ix * iy) / jcp.alpha;
- int max_m_block = nstl::min(M, free_regs);
- int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
- tot_eff = 0.f;
- for (int im = max_m_block; im > 0; im--) {
- if (M % im)
- continue;
- for (int in2 = max_n2_block; in2 > 0; in2--) {
- int used_regs = (im + 1) * in2;
- float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
- float reg_eff = (float)(im * in2) / (im + in2);
- float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
- float cur_tot_eff = get_tot_eff(
- small_mb, thr_eff, work_eff, mem_eff, reg_eff);
- if (jcp.nb_oc % in2 || used_regs > free_regs
- || cur_tot_eff <= tot_eff)
- continue;
- tot_eff = cur_tot_eff;
- m_block = im;
- n2_block = in2;
- }
- }
- };
-
- /* Selecting xb and yb blocking */
- int min_yb = jcp.m;
- int min_xb = jcp.m;
- int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
- int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
- float best_eff = 0.f;
- for (int ix = min_xb; ix <= max_xb; ix += 2) {
- assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
- for (int iy = max_yb; iy >= min_yb; iy -= 2) {
- assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
-
- int m_b[2];
- int n2_b[2];
- bool small_mb;
- float inner_eff, outer_eff, work_eff;
-
- int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
- work_eff = (float)jcp.oh * jcp.ow / tiled_area;
- if (best_eff > 0.f && work_eff < 4.f / 9.f)
- continue; // no gain from Winograd transformation
-
- /* outer parallelization */
- find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
-
- /* inner parallelization */
- find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
-
- small_mb = inner_eff > outer_eff;
- float eff = small_mb ? inner_eff : outer_eff;
- if (eff > best_eff) {
- best_eff = eff;
- jcp.yb = iy;
- jcp.xb = ix;
- jcp.m_block = m_b[small_mb];
- jcp.n2_block = n2_b[small_mb];
- jcp.small_mb = small_mb;
- }
- }
- }
-
- assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
- assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
-
- jcp.mb_block = 1;
- if (jcp.small_mb) {
- // For small mb harness, set mb_block as large as possible subject to
- // the constraint that winograd activations fit into available L3 cache
- int L3_cap = get_cache_size(3, true);
- int M = jcp.xb * jcp.yb / 4;
- int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
- int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
- int max_mb_block = nstl::min(
- jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
- for (int i = max_mb_block; i > 1; i--) {
- if (jcp.mb % i == 0) {
- jcp.mb_block = i;
- break;
- }
- }
- }
- jcp.nb_mb = jcp.mb / jcp.mb_block;
-
- jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
- jcp.N = jcp.oc;
- jcp.K = jcp.ic;
-
- jcp.inp_stride = jcp.M * jcp.ic;
- jcp.out_stride = jcp.M * jcp.oc;
- jcp.wei_stride = jcp.ic * jcp.oc;
- jcp.bia_stride = jcp.oc;
-
- jcp.n_block = jcp.oc_block;
- jcp.k_block = jcp.ic_block;
-
- jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
-
- // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
- // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
- // a multiple of load_block = 16, we just use that for now.
- jcp.k2_block = load_block;
- jcp.k_chunks = jcp.K / jcp.k2_block;
-
- const auto &oscales = attr.output_scales_;
- jcp.is_oc_scale = oscales.mask_ == 1 << 1;
- assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
-
- /* re-create weights primitive descriptor
- and set weights wino_blocking */
- memory_desc_t expect_wei_md = wei_md;
-
- expect_wei_md.format_kind = format_kind::wino;
- expect_wei_md.data_type = data_type::s8;
- mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
- wd.wino_format = mkldnn_wino_wei_aaOIoi;
- wd.r = jcp.r;
- wd.alpha = jcp.alpha;
- wd.ic = jcp.ic;
- wd.oc = jcp.oc;
- wd.ic_block = jcp.ic_block;
- wd.oc_block = jcp.oc_block;
- wd.oc2_block = jcp.n2_block;
- wd.ic2_block = 1;
- wd.adj_scale = adj_wei_scale;
-
- size_t max_size = types::data_type_size(data_type::s8) *
- jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
- max_size += types::data_type_size(data_type::s32) *
- jcp.alpha * jcp.alpha * jcp.oc;
- wd.size = max_size;
-
- if (wei_md.format_kind == format_kind::any)
- wei_md = expect_wei_md;
- if (wei_md != expect_wei_md)
- return status::unimplemented;
-
- const int tilesize = jcp.alpha * jcp.alpha;
- const int numtiles = jcp.M;
- const int alltiles = numtiles * tilesize;
-
- jcp.size_wino_src
- = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
- / jcp.typesize_in;
- jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
- jcp.size_wino_dst = alltiles * jcp.oc;
-
- return status::success;
-}
-////////////////////////////////////////////////////////////////////////////////
-
-template <data_type_t dst_data_type>
-status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
- pd_t::jit_conf() {
- return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
- jcp_, *this->desc(), this->src_md_, this->weights_md_,
-            this->dst_md_, this->bias_md_, *this->attr());
-}
-
-template <data_type_t dst_data_type>
-void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
-init_scratchpad() {
- auto scratchpad = this->scratchpad_registry().registrar();
-
- int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
- scratchpad.book(key_wino_V,
- sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
- scratchpad.book(key_wino_M,
- sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
-
- dim_t scale_count = attr()->output_scales_.count_;
- scratchpad.book(key_conv_adjusted_scales,
- sizeof(float) * nstl::max<dim_t>(scale_count, 16));
-}
-
-template <data_type_t dst_data_type>
-jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
- jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
-{
- kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
- pd()->jcp_, *pd()->attr());
- src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
- pd()->jcp_, *pd()->attr());
- dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
- pd()->jcp_, *pd()->attr());
-}
-
-template <data_type_t dst_data_type>
-jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
- ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
- delete kernel_;
- delete src_trans_;
- delete dst_trans_;
-}
-
-template <data_type_t dst_data_type>
-const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
-adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
- const float *oscales = pd()->attr()->output_scales_.scales_;
- auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / (adj_src_scale * adj_wei_scale);
- if (count == 1)
- utils::array_set(loc_scales, oscales[0] * factor, 16);
- else
- for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
- return loc_scales;
-}
-
-template <data_type_t dst_data_type>
-void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
-execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const auto &jcp = kernel_->jcp;
- if (jcp.small_mb)
- execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx));
- else
- execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx));
-}
-
-template <data_type_t dst_data_type>
-void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
-execute_forward_mbN(const src_data_t *src, const wei_data_t *wei,
- const char *bia, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const float *oscales = adjust_oscales(scratchpad);
-
- auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
- auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
- auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
-
- parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
- [&](int mb, int tile_y_b, int tile_x_b) {
-
- int tile_y = tile_y_b * jcp.yb;
- int tile_x = tile_x_b * jcp.xb;
-
- int ithr = mkldnn_get_thread_num();
- auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
- auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
-
- auto src_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
- auto dst_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
- auto gemm_p =
- jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t();
-
- /* transformation of input tensor to winograd domain */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
- for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
- uint16_t v_y_masks[4], v_x_masks[4];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
-
- int v_ys = nstl::max(0, jcp.t_pad - y);
- int v_ye = nstl::min(jcp.alpha,
- nstl::max(0, jcp.ih + jcp.t_pad - y));
-
- int v_xs = nstl::max(0, jcp.l_pad - x);
- int v_xe = nstl::min(jcp.alpha,
- nstl::max(0, jcp.iw + jcp.l_pad - x));
-
-#pragma unroll(4)
- for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
- v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
- }
- auto local_s = src
- + mb * jcp.ih * jcp.iw * jcp.ic
- + y * jcp.iw * jcp.ic + x * jcp.ic;
- auto local_w = wino_src + m * jcp.ic;
-
- src_trans_p.src = local_s;
- src_trans_p.wino_src = local_w;
- src_trans_p.v_y_masks = v_y_masks;
- src_trans_p.v_x_masks = v_x_masks;
-
- src_trans_->ker_(&src_trans_p);
- }
- }
- /* gemms */
- for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
- // start threads at different GEMMs to help bring weights into LLC
- int offset = (tile_ij + ithr) % 16;
- gemm_p.src = wino_src + jcp.inp_stride * offset;
- gemm_p.dst = wino_dst + jcp.out_stride * offset;
- gemm_p.wei = wei + jcp.wei_stride * offset;
- gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
-
- kernel_->ker_(&gemm_p);
- }
-
- /* transformation from winograd domain to output tensor */
- for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
- for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
- uint16_t v_y_masks[2], v_x_masks[2];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
-
-#pragma unroll(2)
- for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
- v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
- }
- auto local_d = dst
- + mb * jcp.oh * jcp.ow * jcp.oc
- + y * jcp.ow * jcp.oc + x * jcp.oc;
- auto local_w = wino_dst + m * jcp.oc;
-
- auto scales = oscales;
- dst_trans_p.dst = local_d;
- dst_trans_p.wino_dst = local_w;
- dst_trans_p.v_y_masks = v_y_masks;
- dst_trans_p.v_x_masks = v_x_masks;
-
- dst_trans_p.scales = scales;
- dst_trans_p.bias = bia;
-
- dst_trans_->ker_(&dst_trans_p);
- }
- }
- });
-}
-
-template <data_type_t dst_data_type>
-void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
-execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei,
- const char *bia, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- const auto &jcp = kernel_->jcp;
- const float *oscales = adjust_oscales(scratchpad);
-
- auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
- auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
- auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
-
- for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
- for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
- for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
- /* transformation of input tensor to winograd domain */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
- [&](int y_in_block_b, int x_in_block_b, int mb) {
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
-
- auto src_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
-
- uint16_t v_y_masks[4], v_x_masks[4];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
- + (x_in_block / 2);
-
- int v_ys = nstl::max(0, jcp.t_pad - y);
- int v_ye = nstl::min(
- jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
-
- int v_xs = nstl::max(0, jcp.l_pad - x);
- int v_xe = nstl::min(
- jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
-
-#pragma unroll(4)
- for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
- v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
- }
- auto local_s = src
- + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
- + y * jcp.iw * jcp.ic + x * jcp.ic;
- auto local_w = wino_src + m * jcp.ic;
-
- src_trans_p.src = local_s;
- src_trans_p.wino_src = local_w;
- src_trans_p.v_y_masks = v_y_masks;
- src_trans_p.v_x_masks = v_x_masks;
-
- src_trans_->ker_(&src_trans_p);
- });
-
- /* gemms */
- parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
- auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
- call_params_t();
-
- gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
- gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
- + nnb * jcp.n2_block * jcp.n_block;
- gemm_p.wei = wei + jcp.wei_stride * tile_ij
- + nnb * jcp.n2_block * jcp.n_block * jcp.K;
- gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
- + nnb * jcp.n2_block * jcp.n_block;
-
- kernel_->ker_(&gemm_p);
- });
-
- /* transformation from winograd domain to output tensor */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
- [&](int y_in_block_b, int x_in_block_b, int mb) {
- int y_in_block = y_in_block_b * 2;
- int x_in_block = x_in_block_b * 2;
-
- auto dst_trans_p =
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
-
- uint16_t v_y_masks[2], v_x_masks[2];
-
- int y = y_in_block + tile_y;
- int x = x_in_block + tile_x;
- int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
- + (x_in_block / 2);
-
-#pragma unroll(2)
- for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
- v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
- }
- auto local_d = dst
- + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
- + y * jcp.ow * jcp.oc + x * jcp.oc;
- auto local_w = wino_dst + m * jcp.oc;
-
- auto scales = oscales;
- dst_trans_p.dst = local_d;
- dst_trans_p.wino_dst = local_w;
- dst_trans_p.v_y_masks = v_y_masks;
- dst_trans_p.v_x_masks = v_x_masks;
-
- dst_trans_p.scales = scales;
- dst_trans_p.bias = bia;
-
- dst_trans_->ker_(&dst_trans_p);
- });
- }}}
-}
-
-template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
-template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
-template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
-template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
-
-} // namespace cpu
-} // namespace impl
-} // namespace mkldnn
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp
deleted file mode 100644
index 9e6e57b051..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp
+++ /dev/null
@@ -1,128 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP
-#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_primitive_conf.hpp"
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t;
-struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t;
-struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t;
-
-template <data_type_t dst_data_type>
-struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""),
- jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && utils::one_of(desc()->alg_kind,
- alg_kind::convolution_auto,
- alg_kind::convolution_winograd)
- && expect_data_types(data_type::u8, data_type::s8,
- data_type::undef, dst_data_type, data_type::s32)
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, data_type::f32,
- data_type::s32, data_type::s8, data_type::u8))
- && !has_zero_dim_memory()
- && set_default_formats();
-
- if (!ok) return status::unimplemented;
-
- status_t status = jit_conf();
- if (status != status::success) return status;
- set_default_alg_kind(alg_kind::convolution_winograd);
-
- init_scratchpad();
-
- return status;
- }
-
- jit_conv_conf_2x3_wino_t jcp_;
-
- protected:
- status_t jit_conf();
- void init_scratchpad();
-
- bool set_default_formats() {
- using namespace format_tag;
- return set_default_formats_common(nhwc, any, nhwc);
- }
- };
-
- typedef typename prec_traits<data_type::u8>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<data_type::s32>::type acc_data_t;
- typedef typename prec_traits<dst_data_type>::type dst_data_t;
-
- jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd);
- ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t();
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad)
- const;
- void execute_forward(const exec_ctx_t &ctx) const;
- void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei,
- const char *bia, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei,
- const char *bia, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_;
- jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_;
- jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
deleted file mode 100644
index f4ec29ab00..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
+++ /dev/null
@@ -1,820 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_memory.hpp"
-
-#include "jit_uni_1x1_conv_utils.hpp"
-#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
-
-#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position)
-{
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* eltwise before sum */
- return p.contain(eltwise, 0);
- } else if (position == 1) {
- /* eltwise after sum */
- return p.contain(sum, 0) && p.contain(eltwise, 1);
- }
-
- return false;
-}
-
-void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
-{
- mov(aux1_reg_bcast_data, reg_bcast_data);
- mov(aux_reg_bcast_data, reg_bcast_data);
-
- mov(aux_reg_output_data, reg_output_data);
- mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));
-
- Label bcast_loop;
- Label bcast_loop_tail;
-
- cmp(bcast_loop_iter, jcp.ur);
- jl(bcast_loop_tail, T_NEAR);
-
- L(bcast_loop); {
- assert(jcp.bcast_block % jcp.ur == 0);
- int num_substeps = jcp.bcast_block / jcp.ur;
- assert(num_substeps > 0 && num_substeps < 10);
- for (int i = 0; i < num_substeps; i++) {
- reduce_loop(load_loop_blk, jcp.ur, i, false);
- if (i < num_substeps - 1) {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_substep);
- }
- else {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
- - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
- int output_offset = jcp.bcast_loop_output_step
- - (num_substeps - 1) * jcp.bcast_loop_output_substep;
-
- add(aux_reg_output_data, output_offset);
- }
- }
- sub(bcast_loop_iter, jcp.bcast_block);
- cmp(bcast_loop_iter, jcp.bcast_block);
- jge(bcast_loop, T_NEAR);
- }
-
- L(bcast_loop_tail);
- if (jcp.ur_tail) {
- Label bcast_loop_tail_out;
- cmp(bcast_loop_iter, 0);
- jz(bcast_loop_tail_out, T_NEAR);
- reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
- L(bcast_loop_tail_out);
- }
-}
-
-void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in,
- zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
- zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
- switch (type_in) {
- case data_type::f32:
- case data_type::s32: vmovups(zmm, op); break;
- case data_type::s8: vpmovsxbd(zmm, op); break;
- case data_type::u8: vpmovzxbd(zmm, op); break;
- default: assert(!"unsupported data type");
- }
- if (type_in != data_type::f32)
- vcvtdq2ps(zmm_in, zmm_in);
-}
-
-void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
- int ur, int substep, bool wraparound)
-{
- auto vreg_load = [=](int i_load) {
- return Zmm(ur * load_loop_blk + i_load);
- };
-
- auto vreg_accum = [=](int i_load, int i_ur) {
- return Zmm(i_ur * load_loop_blk + i_load);
- };
-
- auto zmm_bias_alpha = [=]() {
- return Zmm(ur * load_loop_blk);
- };
-
- auto xmm_bias_alpha = [=]() {
- return Xmm(ur * load_loop_blk);
- };
- auto bias_ptr = [=](int i_load) {
- return EVEX_compress_addr(reg_bias_data,
- jcp.typesize_bia * jcp.oc_block * i_load);
- };
-
- auto comp_ptr = [=](int i_load) {
- return EVEX_compress_addr(reg_comp_data,
- sizeof(int32_t) * jcp.oc_block * i_load);
- };
-
- auto scale_ptr = [=](int i_load) {
- return EVEX_compress_addr(reg_ptr_scales,
- jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
- };
-
- auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
- assert(i_ur < jcp.ur);
- assert(i_reduce <= jcp.reduce_loop_unroll);
- assert(jcp.reduce_loop_unroll == jcp.reduce_block);
-
- int offt = (jcp.ic_without_padding * i_ur + i_reduce);
-
- return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
- bcast);
- };
-
- auto load_ptr = [=](int i_reduce, int i_load) {
- int u0 = i_reduce % jcp.reduce_loop_unroll;
- int u1 = i_reduce / jcp.reduce_loop_unroll;
-
- int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
-
- return EVEX_compress_addr(aux_reg_load_data,
- u1 * jcp.reduce_loop_load_step
- + jcp.typesize_in * offt);
- };
-
- auto output_ptr = [=](int i_load, int i_ur) {
- return EVEX_compress_addr(aux_reg_output_data,
- jcp.typesize_out * (jcp.oc_without_padding * i_ur
- + i_load * jcp.load_block));
- };
-
- auto init = [=]() {
- for (int i_load = 0; i_load < load_loop_blk; ++i_load)
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- auto r = vreg_accum(i_load, i_ur);
- vpxord(r, r, r);
- }
- if (jcp.signed_input) {
- xor_(reg_scratch, reg_scratch);
- Reg8 _t8 = reg_scratch.cvt8();
- mov(_t8, (int8_t)-128);
- vpbroadcastb(zmm_shift, _t8);
- }
- };
-
- auto store = [=](const bool mask_flag_in) {
- const auto &p = attr_.post_ops_;
- const int sum_idx = p.find(primitive_kind::sum);
- const float *p_sum_scale = (sum_idx != -1)
- ? &p.entry_[sum_idx].sum.scale
- : nullptr;
- mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
- mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
- if (p_sum_scale && *p_sum_scale != 1.f) {
- mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
- mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
- }
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- mov(reg_scratch, float2int(jcp.wei_adj_scale));
- vmovq(xmm_bias_alpha(), reg_scratch);
- vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
- }
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
- auto zmm_bias = zmm_tmp;
- auto zmm_comp = zmm_bcast;
- if (jcp.with_bias) {
- if (jcp.signed_input)
- mov(reg_bias_data,
- EVEX_compress_addr(rsp,reg_bias_data_off));
- cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag);
- if (jcp.signed_input && jcp.ver != ver_vnni)
- vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
- }
- if (jcp.signed_input) {
- mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
- cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag);
- }
-
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- auto r = vreg_accum(i_load, i_ur);
- vcvtdq2ps(r, r);
- if (jcp.signed_input)
- vaddps(r, r, zmm_comp);
- if (jcp.with_bias)
- vaddps(r, r, zmm_bias);
-
- zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
- vmulps(mask_zmm, r, scale_ptr(i_load));
- }
- }
-
- if (maybe_eltwise(0))
- eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
-
- if (p_sum_scale) { // post_op: sum
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- const bool mask_flag = mask_flag_in &&
- i_load == load_loop_blk - 1;
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- auto zmm_prev_dst = zmm_zero;
-
- auto r = vreg_accum(i_load, i_ur);
- cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
- mask_flag);
-
- if (*p_sum_scale == 1.f)
- vaddps(r, zmm_prev_dst);
- else
- vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
- }
- }
- }
-
- if (maybe_eltwise(1))
- eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
-
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- const bool mask_flag = mask_flag_in &&
- i_load == load_loop_blk - 1;
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- auto r = vreg_accum(i_load, i_ur);
- if (jcp.dst_dt == data_type::u8) {
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- vmaxps(r, zmm_zero, r);
- }
- if (jcp.dst_dt != data_type::f32)
- vcvtps2dq(r, r);
- }
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- auto r = vreg_accum(i_load, i_ur);
- zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
-
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32:
- vmovups(output_ptr(i_load, i_ur), r_zmm); break;
- case data_type::s8:
- vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break;
- case data_type::u8:
- vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break;
- default: assert(!"unknown dst_dt");
- }
- }
- }
- mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
- if (p_sum_scale && *p_sum_scale != 1.f)
- mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
- };
-
- auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else {
- vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
- vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
- vpaddd(vreg_acc, vreg_acc, zmm_tmp);
- }
- };
-
- auto fma_block = [=](bool last_block) {
- int reduce_step = 4;
- int tail_size = jcp.ic_without_padding % reduce_step;
- int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
- ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
- : jcp.reduce_loop_unroll;
- for (int i_reduce = 0; i_reduce < loop_unroll;
- i_reduce += reduce_step) {
- for (int i_load = 0; i_load < load_loop_blk; ++i_load)
- vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
- for (int i_ur = 0; i_ur < ur; ++i_ur) {
- if (last_block && tail_size != 0
- && i_reduce == loop_unroll - reduce_step) {
- Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
- for (int r = 0; r < tail_size; ++r)
- vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
- + jcp.ic_without_padding * i_ur + i_reduce + r], r);
- vpbroadcastd(zmm_bcast, xmm_bcast);
- } else {
- vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
- }
- if (jcp.signed_input)
- vpsubb(zmm_bcast, zmm_bcast, zmm_shift);
- for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
- compute(vreg_accum(i_load, i_ur),
- vreg_load(i_load), zmm_bcast);
- }
- }
- }
- };
-
- Label reduce_loop;
- Label reduce_loop_tail;
-
- mov(aux_reg_load_data, reg_load_data);
-
- mov(aux_reg_bcast_data, aux1_reg_bcast_data);
- init();
-
- mov(reduce_loop_iter, reg_reduce_loop_work);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jle(reduce_loop_tail, T_NEAR);
-
- L(reduce_loop); {
- fma_block(false);
- add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jg(reduce_loop, T_NEAR);
- }
-
- L(reduce_loop_tail);
- if (jcp.ic != jcp.ic_without_padding) {
- fma_block(true);
- } else {
- fma_block(false);
- }
-
- if (jcp.oc_without_padding != jcp.oc) {
- Label end_store, common_store;
- mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
-
- /*Check if it is the last load_loop_blk*/
- sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- cmp(reg_load_loop_work, 0);
- jg(common_store, T_NEAR);
-
- /*Check if it is the last ocb*/
- test(reg_reduce_pos_flag, FLAG_OC_LAST);
- jz(common_store, T_NEAR);
-
- store(true);
- jmp(end_store, T_NEAR);
-
- L(common_store);
- store(false);
-
- L(end_store);
-
- add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- } else {
- store(false);
- }
-}
-
-void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
-{
- preamble();
-
- xor_(reg_scratch, reg_scratch);
- Reg16 _t = reg_scratch.cvt16();
- mov(_t, 0x1);
- vpbroadcastw(zmm_one, _t);
-
- sub(rsp, stack_space_needed);
-
- if (jcp.oc_without_padding != jcp.oc) {
- int tail_size = jcp.oc_without_padding % jcp.oc_block;
- int mask = (1 << tail_size) - 1;
- Reg32 regw_tmp = reg_last_load.cvt32();
- mov(regw_tmp, mask);
- kmovw(ktail_mask, regw_tmp);
- }
-
- if (jcp.with_bias)
- mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
- if (jcp.signed_input) {
- mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
- mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
- mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
- }
- mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
- mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
- mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
- mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
- mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
-
- mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
- mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
- mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
- mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
- mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
-
-
- auto load_loop_body = [=](int load_loop_blk) {
- bcast_loop(load_loop_blk);
- add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
- if (jcp.with_bias) {
- if (jcp.signed_input)
- mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
- add(reg_bias_data,
- load_loop_blk * jcp.load_block * jcp.typesize_bia);
- if (jcp.signed_input)
- mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
- }
- if (jcp.signed_input) {
- mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
- add(reg_comp_data,
- load_loop_blk * jcp.load_block * sizeof(int32_t));
- mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
- }
- mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
- mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
- add(reg_ptr_scales,
- jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
- mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
- mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
- add(reg_output_data,
- load_loop_blk * jcp.load_block * jcp.typesize_out);
- sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- };
-
- const int simd_w = 16;
-
- Label load_loop_blk[7];
-
- static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
- const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
- const int *ur_cases_fma = ur_cases_fma_expl_bcast;
- const int *ur_cases = ur_cases_fma;
- const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
-
- for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
- int label_idx = num_ur_cases - ur_idx - 1;
- if (jcp.ur <= ur_cases[ur_idx]) {
- cmp(reg_load_loop_work, simd_w * (label_idx + 1));
- jle(load_loop_blk[label_idx], T_NEAR);
- }
- }
-
- for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
- if (jcp.ur <= ur_cases[ur_idx]) {
- int label_idx = num_ur_cases - ur_idx - 1;
- L(load_loop_blk[label_idx]);
- {
- if (label_idx == 0) {
- cmp(reg_load_loop_work, 0);
- je(load_loop_blk[num_ur_cases], T_NEAR);
- }
-
- for (int _i = 1; _i <= label_idx + 1; _i++) {
- prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]);
- prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]);
- }
-
- load_loop_body(label_idx + 1);
- if (label_idx - 1 > 0) {
- cmp(reg_load_loop_work, 2 * label_idx * simd_w);
- je(load_loop_blk[label_idx - 1], T_NEAR);
- }
- cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
- jge(load_loop_blk[label_idx]);
- }
- for (int idx = label_idx - 1; idx > 0; --idx) {
- cmp(reg_load_loop_work, simd_w * (idx + 1));
- je(load_loop_blk[idx], T_NEAR);
- }
- if (ur_idx < num_ur_cases - 2) {
- cmp(reg_load_loop_work, simd_w);
- jle(load_loop_blk[0], T_NEAR);
- }
- }
- }
- L(load_loop_blk[num_ur_cases]);
-
- add(rsp, stack_space_needed);
-
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
- jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
- using namespace primitive_kind;
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
-
- switch (p.len_) {
- case 0: return true;
- case 1: return is_eltwise(0) || p.contain(sum, 0);
- case 2: return (p.contain(sum, 0) && is_eltwise(1))
- || (p.contain(sum, 1) && is_eltwise(0));
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
- jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
- const primitive_attr_t &attr, int nthreads, bool reduce_src) {
- if (!mayiuse(avx512_core)) return status::unimplemented;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
- || weights_d.data_type() != data_type::s8
- || !one_of(dst_d.data_type(),
- data_type::f32, data_type::s32, data_type::s8, data_type::u8))
- return status::unimplemented;
- jcp.ver = ver_avx512_core;
- if (mayiuse(avx512_core_vnni))
- jcp.ver = ver_vnni;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ic_without_padding = jcp.ic;
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
- jcp.kh = weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + 3];
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
-
- jcp.os = jcp.oh * jcp.ow;
- jcp.is = jcp.ih * jcp.iw;
- jcp.tr_is = rnd_up(jcp.is, 4);
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- format_tag_t dat_tag = format_tag::nhwc;
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.ngroups == 1
- && jcp.src_tag == dat_tag
- && jcp.dst_tag == dat_tag;
- if (!args_ok) return status::unimplemented;
-
- const int simd_w = 16;
-
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
-
- args_ok = true
- && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
- && jcp.t_pad == 0 && jcp.l_pad == 0
- && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
- && jcp.kh == 1 && jcp.kw == 1;
- if (!args_ok) return status::unimplemented;
-
- jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
- jcp.dst_dt = cd.dst_desc.data_type;
-
- jcp.ic_block = jcp.oc_block = simd_w;
-
- jcp.typesize_in = types::data_type_size(src_d.data_type());
- jcp.typesize_out = types::data_type_size(dst_d.data_type());
- jcp.typesize_bia = jcp.with_bias
- ? types::data_type_size(bias_d.data_type())
- : 0;
-
- const int SMALL_SPATIAL = 7 * 7;
- const int BIG_REDUCE_DIM = 1024;
-
- int load_blocking = 0;
- int load_blocking_max = 0;
- int bcast_blocking = 0;
- int bcast_blocking_max = 0;
- int reduce_blocking = 0;
- int reduce_blocking_max = 0;
- jcp.load_grp_count = 1;
- jcp.use_vmovntps = false;
-
- const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
- const int L2_capacity = (L2_size * 3) / 4;
-
- int size_treshold = 28;
- int max_regs = 0;
- int min_regs = 6;
- if (jcp.ver == ver_vnni)
- max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold)
- && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9;
- else
- max_regs = 8;
- jcp.expl_bcast = true;
-
- if (jcp.mb == 1 && jcp.ic > 128
- && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) {
- if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size)
- max_regs = min_regs; // mobilenet_v2 performance improvement
- jcp.ur = nstl::min(max_regs, jcp.os);
- } else {
- const int spatial = jcp.oh;
- jcp.ur = 1;
- for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
- if ((spatial >= size_treshold && spatial % ur_w == 0)
- || (spatial < size_treshold && jcp.os % ur_w == 0)) {
- jcp.ur = ur_w;
- break;
- }
- }
- if (jcp.ur == 1) {
- jcp.ur = nstl::min(max_regs, jcp.os);
- int os_tail = jcp.os % max_regs;
- for (int i = max_regs; i >= min_regs; i--) {
- int i_tail = jcp.os % i;
- if (i_tail > os_tail || i_tail == 0) {
- jcp.ur = i;
- os_tail = i_tail;
- if (i_tail == 0)
- break;
- }
- }
- }
- }
-
- jcp.reduce_dim = jcp.ic;
- jcp.reduce_block = jcp.ic_block;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.is;
-
- jcp.bcast_block = jcp.ur;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.typesize_in;
-
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out;
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in;
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_load_step
- = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
-
- jcp.load_loop_iter_step = jcp.load_block;
-
- jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
-
- int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- reduce_blocking = nb_reduce;
- if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
- reduce_blocking = 64;
- else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
- reduce_blocking = 16;
- reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
- reduce_blocking *= jcp.reduce_block;
-
- bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
- if (cmp_reduce)
- jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
- load_blocking = jcp.load_dim;
-
- jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
- jcp.load_grp_count = best_divider(
- nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
-
- if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) {
- jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
- } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads
- && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
-        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
- load_blocking = jcp.load_block;
- }
-
- bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
- div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
- bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
- bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
-
- int space_for_bcast
- = (L2_capacity - /* kernel_size - */
- 2 * jcp.load_block * reduce_blocking
- - jcp.ur * reduce_blocking - 3 * 1024);
- if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
- space_for_bcast /= 2;
-
- int bcast_in_cache
- = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
- bcast_blocking = nstl::min(
- bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
-
- load_blocking_max = load_blocking;
- bcast_blocking_max = bcast_blocking * 3 / 2;
- reduce_blocking_max = reduce_blocking;
-
- assert(load_blocking);
- assert(load_blocking_max);
- assert(bcast_blocking);
- assert(bcast_blocking_max);
- assert(reduce_blocking);
- assert(reduce_blocking_max);
- assert(load_blocking % jcp.load_block == 0);
- assert(reduce_blocking % jcp.reduce_block == 0);
- assert(load_blocking_max % jcp.load_block == 0);
- assert(reduce_blocking_max % jcp.reduce_block == 0);
-
- assert(jcp.reduce_loop_unroll % 4 == 0);
- assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
-
- assert(jcp.bcast_block % jcp.ur == 0);
- assert(jcp.reduce_dim % jcp.reduce_block == 0);
-
- jcp.ur_tail = jcp.bcast_dim % jcp.ur;
-
- jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
- jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
- jcp.nb_load_blocking = load_blocking / jcp.load_block;
- jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
- jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
- jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
-
- jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
- jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
-    // minimum size of load dim chunk for work distribution within threads
- jcp.nb_load_chunk = 1;
-    // performance improvements for googlenet_v3, mb=1;
-    // TODO: generalize this condition and rewrite it in an appropriate manner
- if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4
- && jcp.ic * jcp.oc <= L2_size) {
- jcp.nb_load_chunk = 4;
- jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count);
- }
-
- const auto &oscales = attr.output_scales_;
- jcp.is_oc_scale = oscales.mask_ == 1 << 1;
- assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
-
- jcp.wei_adj_scale =
-        (weights_d.extra().flags & memory_extra_flags::scale_adjust)
- ? weights_d.extra().scale_adjust : 1.f;
-
- return status::success;
-}
-
-void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
- using namespace mkldnn::impl::memory_tracking::names;
-
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16);
- scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
- }
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
deleted file mode 100644
index 22e9732a1f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
+++ /dev/null
@@ -1,131 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
-#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t)
- jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr) : jcp(ajcp), attr_(attr),
- eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
- }
-
- ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const memory_desc_wrapper &bias_d,
- const primitive_attr_t &attr,
- int nthreads, bool reduce_src);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
-
- bool maybe_eltwise(int position);
-
- jit_1x1_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_1x1_conv_call_s *);
-
- private:
- jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
-
- using reg64_t = const Xbyak::Reg64;
- using zmm_t = const Xbyak::Zmm;
- using mask_t = const Xbyak::Opmask;
-
- reg64_t reg_bcast_data = r8;
- reg64_t reg_ptr_scales = r8;
- reg64_t reg_output_data = r9;
- reg64_t reg_load_data = r10;
- reg64_t reg_ptr_sum_scale = r10;
- reg64_t reg_reduce_loop_work = r11;
- reg64_t reg_bias_data = r12;
- reg64_t reg_comp_data = r12;
- reg64_t reg_scratch = r13;
- reg64_t aux_reg_bcast_data = r14;
- reg64_t aux_reg_load_data = r15;
- reg64_t imm_addr64 = r15;
- reg64_t reg_reduce_pos_flag = rax;
- reg64_t aux1_reg_bcast_data = rbx;
- reg64_t reg_bcast_loop_work = rbx;
- reg64_t bcast_loop_iter = rdx; // Note: Fix me
- reg64_t reg_load_loop_work = rsi;
- reg64_t aux_reg_output_data = abi_not_param1;
- reg64_t reduce_loop_iter = abi_param1;
-
- reg64_t reg_last_load = r8;
- mask_t ktail_mask = k6;
-
- mask_t vmask = k7;
-
- Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28);
- Xbyak::Zmm zmm_one = Xbyak::Zmm(29);
- Xbyak::Zmm zmm_zero = Xbyak::Zmm(30);
- Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31);
- Xbyak::Zmm zmm_shift = Xbyak::Zmm(30);
-
- Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31);
- Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31);
-
- int bcast_loop_work_off = 0;
- int reg_bias_data_off = 8;
- int reg_bcast_data_off = 16;
- int reg_load_data_off = 24;
- int reg_ptr_sum_scale_off = 32;
- int reg_comp_data_off = 40;
- int stack_space_needed = 48;
-
- void bcast_loop(int load_loop_blk);
- void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
-
- void generate();
- void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
- bool mask_flag);
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
deleted file mode 100644
index 0bf09fc677..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
+++ /dev/null
@@ -1,292 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-namespace {
-template <typename T, typename U>
-void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
- T nx, T &nx_start, T &nx_end, T nx_divider)
-{
- const T grp_size = utils::div_up(nthr, nx_divider);
- const T grp_count = utils::div_up(nthr, grp_size);
-
- T grp = ithr / grp_size;
- T grp_ithr = ithr % grp_size;
- T grp_nthr = grp_size;
- T first_grps = nthr % grp_count;
- if (first_grps > 0 && grp >= first_grps) {
- ithr -= first_grps * grp_size;
- grp_nthr--;
- grp = ithr / grp_nthr + first_grps;
- grp_ithr = ithr % grp_nthr;
- }
- balance211(nx, grp_count, grp, nx_start, nx_end);
- balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
-}
-}
-
-/* convolution forward */
-template <data_type_t src_type, data_type_t dst_type>
-void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>::
-execute_forward(const exec_ctx_t &ctx) const
-{
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- auto scratchpad = this->scratchpad(ctx);
-
- if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
- auto local_scales = scratchpad.template get<float>(
- key_conv_adjusted_scales);
- auto scales = pd()->attr()->output_scales_.scales_;
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, scales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = scales[c] * factor;
- }
- }
-
- parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
- execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
- });
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
-::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
- const wei_data_t *weights, const char *bias, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const size_t bia_dt_size = pd()->with_bias()
- ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
-
- const auto &jcp = kernel_->jcp;
- auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
- auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- const int stride_h = pd()->desc()->strides[0];
- const int stride_w = pd()->desc()->strides[1];
- const int pad_t = pd()->desc()->padding[0][0];
- const int pad_l = pd()->desc()->padding[0][1];
-
- const auto &oscales = pd()->attr()->output_scales_;
-
- int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
- * jcp.oc_block * jcp.ic_block;
- wei_data_t *w = const_cast<wei_data_t *>(weights);
- int32_t* compensation = (jcp.signed_input)
- ? reinterpret_cast<int32_t *>(w + offset) : 0;
-
- auto step = [](int default_step, int remaining, int tail_step) {
- assert(default_step <= tail_step);
- return remaining < tail_step ? remaining : default_step;
- };
-
- auto p = jit_1x1_conv_call_s();
-
- auto rp = rtus_driver_t<avx512_common>::call_params_t();
- const int nb_oc = jcp.nb_load;
- const int os_block = jcp.bcast_block;
-
- int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
- balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
- jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
- jcp.load_grp_count);
- if (jcp.nb_load_chunk > 1) {
- ocb_start *= jcp.nb_load_chunk;
- ocb_end *= jcp.nb_load_chunk;
- }
-
- auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
- int &oh, int &ow, int &ih, int &iw)
- {
- int osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
- bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
- jcp.nb_bcast_blocking_max);
- bcast_step = nstl::min(bcast_step, bcast_end - iwork);
-
- const int os = osb * os_block;
- oh = os / jcp.ow;
- ow = os % jcp.ow;
-
- ih = nstl::max(oh * stride_h - pad_t, 0);
- iw = nstl::max(ow * stride_w - pad_l, 0);
- rp.iw_start = iw;
-
- p.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
- rp.os = p.bcast_dim;
- };
-
- auto init_load = [&](int ocb, int &load_step)
- {
- load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
- jcp.nb_load_blocking_max);
- p.load_dim = this_block_size(ocb * jcp.oc_block,
- ocb_end * jcp.oc_block, load_step * jcp.oc_block);
-
- if (ocb + load_step >= nb_oc)
- p.first_last_flag |= FLAG_OC_LAST;
- else
- p.first_last_flag &= ~FLAG_OC_LAST;
-
- };
-
- auto init_reduce = [&]()
- {
- p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
- rp.icb = p.reduce_dim / jcp.reduce_block;
- };
-
- auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
- int ih, int iw)
- {
- const int icb = 0; // Start from the first IC block
- const int _ocb = g * nb_oc + ocb;
- const int _icb = g;
-
- const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
-
- p.output_data = &dst[dst_off];
- p.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
- p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
- p.compensation = (jcp.signed_input)
- ? &compensation[_ocb * jcp.oc_block] : 0;
- p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
- ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
- : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
- if (pd()->rtus_.reduce_src_) {
- rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
- + _icb * jcp.is * jcp.ic_block;
- if (ocb == ocb_start) {
- rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
- rtus_driver_->ker_(&rp);
- }
- p.bcast_data = rp.ws;
- } else
- p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
-
- kernel_->jit_ker(&p);
- };
-
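-    // The loop-order mnemonics below appear to spell out the nesting of the
-    // three dimensions: 'r' = reduce (input channels), 'l' = load (output
-    // channel blocks), 'b' = bcast (mb/group/spatial work). For example,
-    // loop_lbr keeps the oc-block loop outermost and re-runs init_reduce()
-    // inside the bcast loop.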
- if (jcp.loop_order == loop_rlb) {
- init_reduce();
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- inner_ker(ocb, n, g, oh, ow, ih, iw);
- iwork += bcast_step;
- }
- ocb += load_step;
- }
- } else if (jcp.loop_order == loop_lbr) {
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- init_reduce();
- inner_ker(ocb, n, g, oh, ow, ih, iw);
- iwork += bcast_step;
- }
- ocb += load_step;
- }
- } else if (jcp.loop_order == loop_rbl) {
- init_reduce();
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- inner_ker(ocb, n, g, oh, ow, ih, iw);
- ocb += load_step;
- }
- iwork += bcast_step;
- }
- } else if (jcp.loop_order == loop_blr) {
- int iwork = bcast_start;
- while (iwork < bcast_end) {
- int n, g, bcast_step, oh, ow, ih, iw;
- init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
- int ocb = ocb_start;
- while (ocb < ocb_end) {
- int load_step;
- init_load(ocb, load_step);
- init_reduce();
- inner_ker(ocb, n, g, oh, ow, ih, iw);
- ocb += load_step;
- }
- iwork += bcast_step;
- }
- } else {
- assert(!"unsupported loop order");
- }
-}
-
-using namespace data_type;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
-template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
deleted file mode 100644
index ad9027ac17..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
+++ /dev/null
@@ -1,159 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP
-#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
-#include "jit_uni_1x1_conv_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template<impl::data_type_t src_type, impl::data_type_t dst_type>
-struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_(), rtus_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""),
- jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<
- src_type, dst_type>);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, data_type::s8, data_type::undef,
- dst_type, data_type::s32)
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, data_type::f32,
- data_type::s32, data_type::s8, data_type::u8))
- && !has_zero_dim_memory()
- && set_default_formats_common(dat_tag(), format_tag::any,
- dat_tag())
- && set_or_check_wei_format();
- if (!ok) return status::unimplemented;
-
- const convolution_desc_t *conv_d = desc();
- const memory_desc_t *src_d = src_md();
- rtus_prepare(this, conv_d, src_d, dst_md());
-
- status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel::
- init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(),
- *weights_md(1), *attr(), mkldnn_get_max_threads(),
- rtus_.reduce_src_);
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
- scratchpad, jcp_, *attr());
-
- rtus_prepare_space_info(this, scratchpad);
-
- return status::success;
- }
-
- jit_1x1_conv_conf_t jcp_;
- reduce_to_unit_stride_t rtus_;
-
- protected:
- format_tag_t dat_tag() const { return format_tag::nhwc; }
-
- bool set_or_check_wei_format() {
- using namespace format_tag;
-
- const bool is_src_s8 = src_md_.data_type == data_type::s8;
- format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i;
-
- memory_desc_t want_wei_md = weights_md_;
- memory_desc_init_by_tag(want_wei_md, wei_tag);
- if (is_src_s8) {
- want_wei_md.extra.flags = 0
- | memory_extra_flags::compensation_conv_s8s8
- | memory_extra_flags::scale_adjust;
- want_wei_md.extra.compensation_mask = (1 << 0)
- + (with_groups() ? (1 << 1) : 0);
- want_wei_md.extra.scale_adjust =
- mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
- }
-
- if (weights_md_.format_kind == format_kind::any) {
- weights_md_ = want_wei_md;
- return true;
- }
-
- return weights_md_ == want_wei_md;
- }
- };
-
- template <cpu_isa_t isa, typename conv_t>
- friend void init_rtus_driver(conv_t *self);
-
- jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), rtus_driver_(nullptr)
- {
- kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_,
- *pd()->attr());
- init_rtus_driver<avx512_common>(this);
- }
-
- ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() {
- delete kernel_;
- delete rtus_driver_;
- }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
- typedef typename prec_traits<data_type::s32>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
- private:
- void execute_forward(const exec_ctx_t &ctx) const;
- void execute_forward_thr(const int ithr, const int nthr,
- const src_data_t *src, const wei_data_t *weights,
- const char *bias, dst_data_t *dst,
- const memory_tracking::grantor_t &scratchpad) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_;
- rtus_driver_t<avx512_common> *rtus_driver_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp
deleted file mode 100644
index e89d068302..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp
+++ /dev/null
@@ -1,140 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
-#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-#include "type_helpers.hpp"
-#include "primitive_iterator.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_deconvolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_uni_1x1_conv_utils.hpp"
-#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t
- : public cpu_primitive_t {
- struct pd_t : public cpu_deconvolution_fwd_pd_t {
- pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr) {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_fwd_pd_t(other)
- , conv_pd_(other.conv_pd_->clone())
- {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(),
- jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<src_type, dst_type>);
-
- status_t init_convolution() {
- convolution_desc_t cd;
- status_t status;
-
- auto dd = desc();
- status = conv_desc_init(&cd, prop_kind::forward_training,
- alg_kind::convolution_direct, &(dd->src_desc),
- &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc),
- dd->strides, dd->dilates, dd->padding[0], dd->padding[1],
- dd->padding_kind);
-
- if (status == status::success) {
- status = mkldnn_primitive_desc::create<conv_pd_t>(
- &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr);
- }
-
- if (status == status::success)
- status = set_default_params();
-
- return status;
- };
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && desc()->alg_kind == alg_kind::deconvolution_direct
- && !has_zero_dim_memory()
- && desc()->src_desc.data_type == src_type
- && desc()->dst_desc.data_type == dst_type
- && desc()->weights_desc.data_type == data_type::s8
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, data_type::f32,
- data_type::s32, data_type::s8, data_type::u8))
- && desc()->accum_data_type == data_type::s32;
- if (!ok) return status::unimplemented;
-
- CHECK(init_convolution());
-
- return status::success;
- }
-
- virtual void init_scratchpad_md() override {
- const auto conv_1x1_pd = static_cast<conv_pd_t *>(conv_pd_);
- scratchpad_md_ = *conv_1x1_pd->scratchpad_md();
- }
-
- protected:
- status_t set_default_params() {
- auto conv_1x1_pd_ = static_cast<conv_pd_t *>(conv_pd_);
- src_md_ = *conv_1x1_pd_->src_md();
- dst_md_ = *conv_1x1_pd_->dst_md();
- weights_md_ = *conv_1x1_pd_->weights_md();
- if (with_bias())
- bias_md_ = *conv_1x1_pd_->weights_md(1);
- return status::success;
- }
-
- using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
- <src_type, dst_type>::pd_t;
- friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t;
- primitive_desc_t *conv_pd_;
- };
-
- jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
-
- ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t()
- { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- return conv_p_->execute(ctx);
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- primitive_t *conv_p_;
-};
-
-}
-}
-}
-
-#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
deleted file mode 100644
index 10e98a00c4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
+++ /dev/null
@@ -1,1182 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_memory.hpp"
-
-#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
-
-#define GET_OFF(field) offsetof(jit_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-namespace {
-void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
-{
- jcp.loop_order = loop_cwgn;
- if (jcp.ngroups > 1) {
- jcp.loop_order = loop_ngcw;
- if (jcp.mb < nthr)
- jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
- }
-}
-}
-
-template<typename Vmm>
-bool _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::maybe_eltwise(int position)
-{
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* eltwise before sum */
- return p.contain(eltwise, 0);
- } else if (position == 1) {
- /* eltwise after sum */
- return p.contain(sum, 0) && p.contain(eltwise, 1);
- }
-
- return false;
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
-{
- int nb_oc_block
- = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- for (int k = 0; k < nb_oc_block; k++)
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- vpxord(vmm, vmm, vmm);
- }
- if (jcp.signed_input) {
- xor_(reg_scratch, reg_scratch);
- if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
- Reg32 _t32 = reg_scratch.cvt32();
- mov(_t32, (uint32_t)128);
- vpbroadcastd(vmm_shift, _t32);
- } else {
- Reg8 _t8 = reg_scratch.cvt8();
- mov(_t8, (int8_t)128);
- vpbroadcastb(vmm_shift, _t8);
- }
- }
-}
-
-template<typename Vmm>
-const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
- vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {
- return vmm_in;
-}
-
-template<>
-const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
- vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
- return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
- : zmm_in;
-}
-
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
- const Vmm vmm_in, const Operand &op, bool mask_flag) {
- //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
- const Vmm vmm = vmm_mask(vmm_in, mask_flag);
- switch (type_in) {
- case data_type::f32:
- case data_type::s32: vmovups(vmm, op); break;
- case data_type::s8: vpmovsxbd(vmm, op); break;
- case data_type::u8: vpmovzxbd(vmm, op); break;
- default: assert(!"unsupported data type");
- }
- if (type_in != data_type::f32)
- vcvtdq2ps(vmm_in, vmm_in);
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_eltwise(int ur_w) {
- int nb_oc_block
- = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- if (ur_w == jcp.ur_w)
- eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
- else
- for (int k = 0; k < nb_oc_block; k++)
- eltwise_injector_->compute_vector_range(k * jcp.ur_w,
- k * jcp.ur_w + ur_w);
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
- int ur_w, bool last_oc_block_flag) {
- int nb_oc_block
- = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
-
- mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
- mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
- if (jcp.signed_input)
- mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
-
- const auto &p = attr_.post_ops_;
- const int sum_idx = p.find(primitive_kind::sum);
- const float *p_sum_scale = nullptr;
- if (sum_idx != -1) {
- const auto &p_entry = p.entry_[sum_idx];
- p_sum_scale = &p_entry.sum.scale;
- }
-
- if (p_sum_scale && *p_sum_scale != 1.f)
- mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
-
- if (jcp.signed_input && jcp.ver != ver_vnni) {
-        /* apply 'wei_adj_scale = 0.5' to the bias calculation */
- mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
- vmovq(xmm_bias_alpha(), reg_bias_alpha);
- vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
- }
-
- for (int k = 0; k < nb_oc_block; k++) {
- const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
- int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
- if (jcp.with_bias) {
- int bias_offset = jcp.typesize_bia * k * oc_block;
- auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
-
- cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
- if (jcp.signed_input && jcp.ver != ver_vnni)
- /* bias *= 0.5 */
- vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
- }
- if (jcp.signed_input) {
- int comp_offset = sizeof(int32_t) * k * oc_block;
- auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
-
- cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
- }
- /* add to zmm_accum: compensation, bias and permute */
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- if (jcp.is_fast_depthwise)
- vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
- vcvtdq2ps(vmm, vmm);
- if (jcp.signed_input)
- vaddps(vmm, vmm, vmm_comp);
- if (jcp.with_bias)
- vaddps(vmm, vmm, vmm_bias);
-
- const Vmm vmm_k = vmm_mask(vmm, mask_flag);
- vmulps(vmm_k, vmm,
- EVEX_compress_addr(reg_ptr_scales, scale_offset));
- }
- }
-
- /* Do post-ops */
- if (maybe_eltwise(0)) compute_eltwise(ur_w);
- if (p_sum_scale) { // post_op: sum
- for (int k = 0; k < nb_oc_block; k++) {
- const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
- for (int j = 0; j < ur_w; j++) {
- int aux_output_offset
- = jcp.typesize_out
- * (k * oc_block
- + j * jcp.oc_without_padding * jcp.ngroups);
- auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
- Vmm vmm = vmm_out(j, k);
- cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
- if (*p_sum_scale == 1.f)
- vaddps(vmm, vmm_prev_dst);
- else
- vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
- }
- }
- }
- if (maybe_eltwise(1)) compute_eltwise(ur_w);
-
- /* write out register to output_addr */
- for (int k = 0; k < nb_oc_block; k++) {
- const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
- for (int j = 0; j < ur_w; j++) {
- Vmm vmm = vmm_out(j, k);
- if (jcp.dst_dt == data_type::u8) {
- vpxord(vmm_zero, vmm_zero, vmm_zero);
- vmaxps(vmm, vmm_zero, vmm);
- }
-
- if (jcp.dst_dt != data_type::f32) {
- /* Note: using Zmm for rounding in Xmm/Ymm kernel
- because there is no instruction to do rounding
- from Xmm/Ymm -> Xmm/Ymm.
- Embedded rounding is not supported for Xmm.
- TODO: maybe avoid Zmm if it helps performance.*/
- Zmm zmm = zmm_out(j, k);
- vcvtps2dq(zmm, zmm);
- }
- }
-
- for (int j = 0; j < ur_w; j++) {
- int aux_output_offset = jcp.typesize_out
- * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
- auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
-
- Vmm vmm = vmm_out(j, k);
- const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
-
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32: vmovups(addr, r_vmm); break;
- case data_type::s8: vpmovsdb(addr, r_vmm); break;
- case data_type::u8: vpmovusdb(addr, r_vmm); break;
- default: assert(!"unknown dst_dt");
- }
- }
- }
-
-}
-
-template <typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
- int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
- assert(!"invalid group blocking for depthwise convolution");
-}
-
-template <>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
- int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
-
- auto input_spatial_index = [=](int oi, int ki) {
- return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
- };
-
- auto input_offset2 = [=](int ii, int ci) {
- return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
- };
-
- auto input_offset3 = [=](int oi, int ci, int ki) {
- return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
- };
-
- auto kernel_offset = [=](int ci, int ki) {
- return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
- };
-
- auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
- // okay for depthwise since src is zero-extended
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else {
- vpmaddwd(zmm_tmp, vreg_src, vreg_wei);
- vpaddd(vreg_acc, vreg_acc, zmm_tmp);
- }
- };
-
- int ii_start = 0;
- int ii_end = -1;
- if (jcp.is_resrc_depthwise && !h_padded) {
- // find bounds of input spatial indices
- bool first = true;
- for (int ki = 0; ki < jcp.kw; ki++) {
- int oi_start = get_ow_start(ki, pad_l);
- int oi_end = get_ow_end(ur_w, ki, pad_r);
- for (int oi = oi_start; oi < oi_end; oi++) {
- int ii = input_spatial_index(oi, ki);
- if (first || ii < ii_start)
- ii_start = ii;
- if (first || ii > ii_end)
- ii_end = ii;
- first = false;
- }
- }
- }
-
- if (jcp.signed_input) {
- vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero);
- vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift);
- }
- for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
- const bool mask_flag = last_ic_block_flag != no_last_block
- && ci == jcp.nb_ch_blocking - 1;
- if (jcp.is_resrc_depthwise && !h_padded) {
- // now we can load input once and reuse up to jcp.kw times
- for (int ii = ii_start; ii <= ii_end; ii++) {
- int aux_input_offset = input_offset2(ii, ci);
- const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
- const Zmm zmm_inp_msk = mask_flag
- ? zmm_inp_tmp | ktail_mask | T_z
- : zmm_inp_tmp;
- if (jcp.is_fast_depthwise) {
- assert(!mask_flag);
- vbroadcasti32x4(zmm_inp_msk,
- EVEX_compress_addr(aux_reg_inp, aux_input_offset));
- } else {
- vpmovzxbd(zmm_inp_msk,
- EVEX_compress_addr(aux_reg_inp, aux_input_offset));
- }
- if (jcp.signed_input)
- vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift);
- }
- }
- for (int ki = 0; ki < jcp.kw; ki++) {
- int aux_kernel_offset = kernel_offset(ci, ki);
- if (jcp.is_fast_depthwise) {
- vbroadcasti32x4(zmm_wei,
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei);
- } else {
- vpmovsxbd(zmm_wei,
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- }
- if (h_padded) {
- assert(jcp.signed_input);
- for (int oi = 0; oi < ur_w; oi++)
- compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
- } else {
- const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
- int oi_start = get_ow_start(ki, pad_l);
- int oi_end = get_ow_end(ur_w, ki, pad_r);
- int start_ = jcp.signed_input ? 0 : oi_start;
- int end_ = jcp.signed_input ? ur_w : oi_end;
- for (int oi = start_; oi < end_; oi++) {
- if (oi >= oi_start && oi < oi_end) {
- if (jcp.is_resrc_depthwise) {
- int ii = input_spatial_index(oi, ki);
- zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
- } else {
- int aux_input_offset = input_offset3(oi, ci, ki);
- if (jcp.is_fast_depthwise) {
- assert(!mask_flag);
- vbroadcasti32x4(r_zmm_src,
- EVEX_compress_addr(aux_reg_inp,
- aux_input_offset));
- } else {
- vpmovzxbd(r_zmm_src,
- EVEX_compress_addr(aux_reg_inp,
- aux_input_offset));
- }
- if (jcp.signed_input)
- vpaddb(zmm_src, zmm_src, vmm_shift);
- }
- } else if (jcp.signed_input) {
- zmm_src = zmm_shifted_zero;
- }
- compute(zmm_out(oi, ci), zmm_wei, zmm_src);
- }
- }
- }
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
- int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
- if (jcp.is_depthwise)
- return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
-
- int kw = jcp.kw;
- int stride_w = jcp.stride_w;
- int ic_block = jcp.ic_block;
- int oc_block = jcp.oc_block;
- int ch_block_all = jcp.ch_block * ic_block * oc_block;
-
- int nb_oc_block = jcp.nb_oc_blocking;
-
- auto input_offset = [=](int oi, int ic, int ki) {
- return jcp.typesize_in
- * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
- * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
- };
- auto kernel_offset = [=](int ii, int ic, int ki) {
- return jcp.typesize_in
- * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
- + 4 * ic * oc_block);
- };
- auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else {
- vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
- vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
- vpaddd(vreg_acc, vreg_acc, vmm_tmp);
- }
- };
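-    // A sketch of the pre-VNNI emulation used above (the standard int8
-    // sequence, not specific to this kernel): vpmaddubsw multiplies unsigned
-    // src bytes by signed weight bytes and adds adjacent pairs into saturated
-    // 16-bit lanes, vpmaddwd then multiplies those lanes by vmm_one (a vector
-    // of 16-bit ones) to widen adjacent pairs into 32-bit sums, and vpaddd
-    // accumulates. Per 32-bit lane, with hypothetical bytes s0..s3 and w0..w3:
-    //   acc += s0*w0 + s1*w1 + s2*w2 + s3*w3
-    // The requirement that the first operand be unsigned is why signed src is
-    // shifted by 128 (vmm_shift) and compensated for elsewhere.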
-
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = get_ow_start(ki, pad_l);
- int jj_end = get_ow_end(ur_w, ki, pad_r);
- int tail_size = jcp.ic_without_padding % 4;
- int _start = (jcp.signed_input) ? 0 : jj_start;
- int _end = (jcp.signed_input) ? ur_w : jj_end;
- /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
- int icb = (last_ic_block_flag != no_last_block)
- ? div_up((jcp.ic_without_padding % ic_block), 4)
- : ic_block / 4;
- for (int ic = 0; ic < icb; ic++) {
- if (h_padded == true) {
- /* fill padded area with shifted values */
- Vmm inp = vmm_inp(0,nb_oc_block);
- vpxord(inp, inp, inp);
- vpaddb(inp, inp, vmm_shift);
- } else {
- for (int jj = _start; jj < _end; jj++) {
- int aux_input_offset = input_offset(jj, ic, ki);
- if (jj >= jj_start && jj < jj_end) {
- if (last_ic_block_flag == last_sp_block
- && tail_size != 0 && ic == icb - 1) {
- Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
- for (int r = 0; r < tail_size; ++r)
- vpinsrb(xmm_tmp, xmm_tmp,
- ptr[aux_reg_inp + aux_input_offset + r], r);
- vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
- } else {
- vpbroadcastd(vmm_inp(jj, nb_oc_block),
- EVEX_compress_addr(
- aux_reg_inp, aux_input_offset));
- }
- if (jcp.signed_input)
- vpaddb(vmm_inp(jj, nb_oc_block),
- vmm_inp(jj, nb_oc_block), vmm_shift);
- } else {
- /* fill padded area with shifted values */
- if (jcp.signed_input) {
- Vmm inp = vmm_inp(jj, nb_oc_block);
- vpxord(inp, inp, inp);
- vpaddb(inp, inp, vmm_shift);
- }
- }
- }
- }
- for (int ii = 0; ii < nb_oc_block; ii++) {
- int aux_kernel_offset = kernel_offset(ii, ic, ki);
- vmovups(vmm_wei,
- EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
- for (int jj = _start; jj < _end; jj++) {
- Vmm inp = (h_padded == true)
- ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block);
- compute(vmm_out(jj, ii), vmm_wei, inp);
- }
- }
- }
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
- int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
- Label kh_label, skip_kh_loop;
- Label t_overflow_label, no_t_overflow_label,
- b_overflow_label, no_b_overflow_label;
-
- int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
- int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
- int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
- * jcp.ic_without_padding * jcp.ngroups;
-
- mov(aux_reg_inp, reg_inp);
- mov(aux_reg_ker, reg_ker);
-
- if (jcp.signed_input && jcp.ndims > 3) {
- mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
- cmp(reg_overflow, 0);
- je(no_t_overflow_label, T_NEAR);
- L(t_overflow_label); {
- compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
-
- add(aux_reg_ker, shift_kernel_ptr);
- dec(reg_overflow);
- cmp(reg_overflow, 0);
- jg(t_overflow_label, T_NEAR);
- }
- L(no_t_overflow_label);
- }
- mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
- if ((jcp.signed_input) || (!jcp.signed_input &&
- (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
- cmp(reg_kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
- L(kh_label); {
- compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
-
- add(aux_reg_ker, shift_kernel_ptr);
- add(aux_reg_inp, shift_input_ptr);
- dec(reg_kj);
- cmp(reg_kj, 0);
- jg(kh_label, T_NEAR);
- }
- L(skip_kh_loop);
- if (jcp.signed_input && jcp.ndims > 3) {
- mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
- cmp(reg_overflow, 0);
- je(no_b_overflow_label, T_NEAR);
- L(b_overflow_label); {
- compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
-
- add(aux_reg_ker, shift_kernel_ptr);
- dec(reg_overflow);
- cmp(reg_overflow, 0);
- jg(b_overflow_label, T_NEAR);
- }
- L(no_b_overflow_label);
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
- int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
-{
- prepare_output(ur_w);
-
- // IC loop
- Label icb_label;
- mov(reg_icb, jcp.nb_ic);
- L(icb_label);
- if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
- Label common_ker, end_ker;
-
- cmp(reg_icb, 1); // The last IC block
- jne(common_ker, T_NEAR);
-
- kh_loop(ur_w, pad_l, pad_r,
- is_last_sp_block ? last_sp_block : last_ic_block);
- jmp(end_ker, T_NEAR);
-
- L(common_ker);
- kh_loop(ur_w, pad_l, pad_r, no_last_block);
-
- L(end_ker);
- } else {
- kh_loop(ur_w, pad_l, pad_r, no_last_block);
- }
- // End of IC Loop
- int inp_step = jcp.ic_block;
- int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
- add(reg_inp, jcp.typesize_in * inp_step);
- add(reg_ker, jcp.typesize_in * ker_step);
-
- dec(reg_icb);
- cmp(reg_icb, 0);
- jg(icb_label, T_NEAR);
-
- sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
- sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
-
- if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
- Label common_store, end_store;
-
- if (jcp.is_depthwise)
- cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
- else
- cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
-
- jne(common_store, T_NEAR);
-
- store_output(ur_w, true); // last oc block
- jmp(end_store, T_NEAR);
-
- L(common_store);
- store_output(ur_w, false);
-
- L(end_store);
- } else {
- store_output(ur_w, false);
- }
-}
-
-template<typename Vmm>
-void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
-{
- Label permute_index_table;
- int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
- * jcp.ic_without_padding * jcp.ngroups;
- int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
- * jcp.ic_without_padding * jcp.ngroups;
- int inp_shift = jcp.typesize_in *
- (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
- * jcp.ngroups);
- int out_shift = jcp.typesize_out *
- (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
- preamble();
-
- if (jcp.is_depthwise) {
- int idx = jcp.max_regs_ur - 1;
- if (!jcp.is_resrc_depthwise)
- zmm_src = Zmm(++idx);
- if (jcp.ver != ver_vnni)
- zmm_tmp = Zmm(++idx);
- if (jcp.is_fast_depthwise)
- zmm_permute = Zmm(++idx);
- if (jcp.signed_input) {
- zmm_shifted_zero = Zmm(++idx);
- ++idx; // due to extra register used for shifts and compensations
- }
- assert(idx == ker_dw_reg_base_idx);
- }
-
- if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
- xor_(reg_scratch, reg_scratch);
- Reg16 _t16 = reg_scratch.cvt16();
- mov(_t16, 0x1);
- vpbroadcastw(vmm_one, _t16);
- }
-
- mov(reg_inp, ptr[param1 + GET_OFF(src)]);
- mov(reg_out, ptr[param1 + GET_OFF(dst)]);
- mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
-
- if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
- int tail_size = jcp.is_depthwise
- ? jcp.ngroups % jcp.ch_block
- : jcp.oc_without_padding % jcp.oc_block;
- int mask = (1 << tail_size) - 1;
- mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
- Reg32 regw_tmp = reg_oi.cvt32();
- mov(regw_tmp, mask);
- kmovw(ktail_mask, regw_tmp);
- }
- if (jcp.is_fast_depthwise) {
- // prepare mask register for blending weights
- mov(reg_scratch, 0x8888444422221111);
- kmovq(kblend_mask, reg_scratch);
- // load permute indices from data section
- mov(reg_scratch, permute_index_table);
- vmovdqu32(zmm_permute, ptr[reg_scratch]);
- }
-
- int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1));
- int n_oi = jcp.ow / jcp.ur_w;
- int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
-
- if (jcp.nb_ow == 1) {
- if (r_pad1 > 0 || jcp.ur_w_tail == 0)
- n_oi--;
-
- xor_(reg_oi, reg_oi);
- if (jcp.ow == jcp.ur_w) {
- icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
- } else {
- if (n_oi == 0) {
- icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
- if (jcp.ur_w_tail != 0) {
- icb_loop(jcp.ur_w_tail, 0, r_pad, true);
- }
- } else {
- if (jcp.l_pad > 0) {
- icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
-
- inc(reg_oi);
- }
- if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1))
- {
- Label ow_loop_label;
- L(ow_loop_label); {
- icb_loop(jcp.ur_w, 0, 0, false);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
-
- inc(reg_oi);
- cmp(reg_oi, n_oi);
- jl(ow_loop_label, T_NEAR);
- }
- }
- if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
- icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
- }
- if (jcp.ur_w_tail != 0) {
- icb_loop(jcp.ur_w_tail, 0, r_pad, true);
- }
- }
- }
- } else {
-        // Only one ow block is processed per kernel invocation.
-        // The block index is passed in the owb parameter,
-        // and the padding handling depends on that index.
- Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
- oi_loop_label, oi_loop_end_label;
-
- assert(jcp.ow_block % jcp.ur_w == 0);
- int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
-        // to simplify the code (and general-purpose register usage),
-        // the ow block size must be >= 2 * ur_w
- assert(n_oi_not_last_ow_block > 1);
- int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
- int n_oi_first_ow_block = n_oi_not_last_ow_block;
- int n_oi_last_ow_block
- = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
- // prepare right padding
- bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
- bool first_ow_block_padded
- = next_last_ow_block_padded && jcp.nb_ow == 2;
- bool last_ow_block_padded
- = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
-
- if (last_ow_block_padded) n_oi_last_ow_block--;
- else if (first_ow_block_padded) n_oi_first_ow_block--;
- else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
- cmp(reg_owb, 0); // is that the first ow-block ?
- jg(middle_ow_blocks_label, T_NEAR);
-
- // the first ow block, compute left padding
- mov(reg_oi, n_oi_first_ow_block);
- if (jcp.l_pad > 0) {
- icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
- add(reg_inp, inp_shift_pad);
- add(reg_out, out_shift);
-
- dec(reg_oi);
- }
- jmp(oi_loop_label, T_NEAR);
-
- // middle or last ow block entry
- L(middle_ow_blocks_label);
-
- if (jcp.l_pad > 0) {
-            // account for the left padding here; nothing is computed
- add(reg_inp, inp_shift_pad_second_block);
- }
-
-        // set the number of iterations for the oi-loop
- if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
- cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
- mov(reg_oi, n_oi_last_ow_block);
- je(oi_loop_label, T_NEAR);
- }
-
- if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
- cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
-
- mov(reg_oi, n_oi_next_last_ow_block);
- je(oi_loop_label, T_NEAR);
- }
- mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
-
- // oi loop w/o padding
- L(oi_loop_label); {
- cmp(reg_oi, 0);
- jle(oi_loop_end_label, T_NEAR);
-
- icb_loop(jcp.ur_w, 0, 0, false);
-
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
- dec(reg_oi);
-
- jmp(oi_loop_label, T_NEAR);
- }
- L(oi_loop_end_label);
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
- cmp(reg_owb, 0); // first ow-block ?
- if (first_ow_block_padded)
- je(last_oi_label, T_NEAR);
- else
- je(end_label, T_NEAR);
-
- cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
- jl(end_label, T_NEAR);
- if (next_last_ow_block_padded)
- je(last_oi_label, T_NEAR);
- else
- je(end_label, T_NEAR);
-
-        // this is the last ow block
- if (!last_ow_block_padded)
- jmp(tail_label, T_NEAR);
-
- // last oi block with right padding
- L(last_oi_label);
- icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
- add(reg_inp, inp_shift);
- add(reg_out, out_shift);
-
- mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
- cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
- jl(end_label, T_NEAR);
-
- // ur_w tail
- L(tail_label);
- if (jcp.ur_w_tail != 0) {
- icb_loop(jcp.ur_w_tail, 0, r_pad, true);
- }
- L(end_label);
- }
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-
- if (jcp.is_fast_depthwise) {
- align(64);
- L(permute_index_table);
- const uint32_t _idx[]
- = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
- for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
- dd(_idx[i]);
- }
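-    // The permute table above (0,4,8,12, 1,5,9,13, ...) appears to describe a
-    // 4x4 transpose of 32-bit lanes; store_output applies it with vpermd to
-    // undo the interleaved accumulation layout of the fast-depthwise path.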
-}
-
-bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr)
-{
- using namespace primitive_kind;
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
-
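-    // Accepted chains, per the switch below: no post-ops, a single eltwise,
-    // a single sum, or the two-entry combinations sum+eltwise / eltwise+sum.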
- switch (p.len_) {
- case 0: return true;
- case 1: return is_eltwise(0) || p.contain(sum, 0);
- case 2: return (p.contain(sum, 0) && is_eltwise(1)) ||
- (p.contain(sum, 1) && is_eltwise(0));
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, memory_desc_t &src_md,
- memory_desc_t &weights_md, memory_desc_t &dst_md,
- memory_desc_t &bias_md, const primitive_attr_t &attr,
- int nthreads)
-{
- using namespace prop_kind;
-
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper weights_d(&weights_md);
- const memory_desc_wrapper dst_d(&dst_md);
- const memory_desc_wrapper bias_d(&bias_md);
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- int ndims = src_d.ndims();
- bool is_1d = ndims == 3;
-
- if (!(mayiuse(avx512_core)
- && one_of(src_d.data_type(), data_type::u8, data_type::s8)
- && weights_d.data_type() == data_type::s8
- && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
- data_type::s8, data_type::u8)))
- return status::unimplemented;
-
- jcp = zero<decltype(jcp)>();
- jcp.ndims = ndims;
- jcp.prop_kind = cd.prop_kind;
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.ic_without_padding = jcp.ic;
- jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
- jcp.ow = dst_d.dims()[ndims - 1];
- jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
- jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
- jcp.l_pad = cd.padding[0][ndims - 3];
- jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
- jcp.stride_w = cd.strides[ndims - 3];
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- jcp.ur_h = 1; /* no code-unrolling by h so far */
-
- jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
- jcp.dilate_w = cd.dilates[ndims - 3];
-
- jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
- jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
-
- if (jcp.is_depthwise) {
- jcp.ch_block = 16;
- jcp.ic_block = 1;
- jcp.oc_block = 1;
- } else {
- jcp.ch_block = 1;
- jcp.ic_block = 16;
- jcp.oc_block = 16;
-
- if (jcp.ngroups == 1) {
-            /* For non-grouped convolutions, pad the channel counts up to a multiple of 16 if needed */
- jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
- jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
- } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
- /* For grouped convolutions, MKL-DNN doesn't support padding.
-               Use Ymm when the number of channels per group is a multiple
-               of 8, and Xmm when it is a multiple of 4 */
- jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
- jcp.oc_block = jcp.ic_block;
- }
- if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0)
- return status::unimplemented;
- }
-
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
- jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
-            && jcp.ngroups % jcp.ch_block == 0; // a group count that is not
-                                                // a multiple of 16 would need
-                                                // byte masking for the src load
- jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
- && jcp.kw < 4 && jcp.dilate_w == 0;
- if (jcp.is_depthwise) {
- jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
- - 2 * jcp.signed_input - (jcp.ver != ver_vnni);
- } else {
- jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
- }
-
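-    // The weight tags chosen below (e.g. OIhw4i16o4i) presumably pack groups
-    // of 4 input channels contiguously for a 16-wide output-channel block, so
-    // that each 32-bit broadcast of src feeds one vpdpbusd / vpmaddubsw step.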
- auto set_or_check_wei_format = [&]() {
- using namespace format_tag;
- format_tag_t wei_tag;
- if (jcp.ic_block == 16 || jcp.ch_block == 16) {
- if (is_1d) {
- wei_tag = with_groups
- ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
- : OIw4i16o4i;
- } else {
- wei_tag = with_groups
- ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
- : OIhw4i16o4i;
- }
- } else if (with_groups && jcp.ic_block == 8) {
- wei_tag = gOIhw2i8o4i;
- } else
- wei_tag = gOIhw4o4i;
-
- memory_desc_t want_wei_md = weights_md;
- memory_desc_init_by_tag(want_wei_md, wei_tag);
- if (jcp.signed_input) {
- want_wei_md.extra.flags = 0
- | memory_extra_flags::compensation_conv_s8s8
- | memory_extra_flags::scale_adjust;
- want_wei_md.extra.compensation_mask = (1 << 0)
- + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
- want_wei_md.extra.scale_adjust =
- mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
- }
-
- if (weights_md.format_kind == format_kind::any) {
- weights_md = want_wei_md;
- return true;
- }
-
- return weights_md == want_wei_md;
- };
-
- if (!set_or_check_wei_format())
- return status::unimplemented;
-
- format_tag_t dat_tag = utils::pick(ndims - 3,
- format_tag::nwc, format_tag::nhwc);
-
- if (src_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md, dat_tag));
- jcp.src_tag = dat_tag;
- } else {
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- }
- if (jcp.src_tag != dat_tag)
- return status::unimplemented;
-
- if (dst_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
- jcp.dst_tag = dat_tag;
- } else {
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
- }
- if (jcp.dst_tag != dat_tag)
- return status::unimplemented;
-
- if (jcp.with_bias) {
- if (bias_d.format_kind() == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
- }
-
- jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
- jcp.dst_dt = cd.dst_desc.data_type;
-
- jcp.typesize_in = types::data_type_size(src_d.data_type());
- jcp.typesize_out = types::data_type_size(dst_d.data_type());
- jcp.typesize_bia = jcp.with_bias
- ? types::data_type_size(bias_d.data_type())
- : 0;
-
- jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
- jcp.nb_ic = jcp.ic / jcp.ic_block;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
- int nb_ch_blocking = 4;
- for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--)
- if (jcp.nb_ch % nb_ch_blocking == 0)
- break;
- jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
-
- // If OC blocking is incommensurate with the number of OC blocks (general
- // requirement for all convolutions), or if it results in an unrolling
- // factor smaller than the left padding (special requirement for SSD:fc6),
- // then search for a smaller OC blocking that satisfies both constraints.
- auto is_oc_blocking_ok = [&](int block) {
- int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
- return jcp.nb_oc % block == 0
- && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1;
- };
-
- // choose nb_oc work chunk size for distribution within threads
- int max_threading_nb_oc_chunk = 4;
- // Performance improvements for googlenet_v3 and resnet_50 with mb = 1;
- // TODO: generalize this condition and rewrite it in appropriate manner
- if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3
- && jcp.stride_w == 1 && jcp.ic % 64 == 0)
- max_threading_nb_oc_chunk = 2;
- jcp.nb_oc_blocking_thr_chunk =
- nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
- for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
- if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk))
- break;
- }
-
- // choose oc blocking for computational kernel
- jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
- // Performance improvements for googlenet_v3 with mb = 1;
- // TODO: generalize this condition and rewrite it in appropriate manner
- const int size_treshold_for_nb_oc_blocking_reduction = 17;
- if (jcp.mb == 1 && jcp.ow <= size_treshold_for_nb_oc_blocking_reduction
- && jcp.stride_w == 1
- && !(jcp.kh == 1 && jcp.kw == 3)
- && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) {
- const int max_nb_oc_blocking = 2;
- jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc);
- for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
- if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0
- && is_oc_blocking_ok(jcp.nb_oc_blocking))
- break;
- }
-
- if (jcp.is_resrc_depthwise)
- jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
- / (jcp.nb_ch_blocking + jcp.stride_w);
- else
- jcp.ur_w
- = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking
- : jcp.nb_oc_blocking + 1);
- if (jcp.ow < jcp.ur_w)
- jcp.ur_w = jcp.ow;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
-
- jcp.ow_block = jcp.ow;
- int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh
- * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
- float best_thr_eff
- = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
- int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
- for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
- int ow_block
- = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
- if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
- && best_thr_eff > 0.8f)
- break;
- if (div_up(jcp.ow, ow_block) != nb_ow)
- continue;
- auto work_amount = base_work_amount * nb_ow;
- float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
- if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
- jcp.ow_block = ow_block;
- best_thr_eff = thr_eff;
- }
- if (best_thr_eff > 0.9f)
- break;
- }
- jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
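-    // Hypothetical numbers for the heuristic above: with base_work_amount = 20
-    // and nthreads = 16, thr_eff = 20 / rnd_up(20, 16) = 20/32 = 0.625;
-    // splitting ow into nb_ow = 2 blocks gives 40 / rnd_up(40, 16) = 40/48
-    // ~ 0.83, which exceeds 1.1 * 0.625, so the split is kept provided
-    // ow_block is still at least 2 * ur_w.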
-
- bool args_ok = true
- && jcp.oc % jcp.oc_block == 0
- && jcp.l_pad <= jcp.ur_w
- && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
- if (!args_ok)
- return status::unimplemented;
-
- int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1));
- if (r_pad_no_tail > jcp.ur_w)
- return status::unimplemented;
-
- pick_loop_order(jcp, nthreads);
-
- jcp.nb_ic_L2 = jcp.nb_ic;
-
- const auto &oscales = attr.output_scales_;
- jcp.is_oc_scale = oscales.mask_ == 1 << 1;
-
- assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
-
- jcp.wei_adj_scale =
-        (weights_d.extra().flags & memory_extra_flags::scale_adjust)
- ? weights_d.extra().scale_adjust : 1.f;
-
- return status::success;
-}
-
-void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
- const primitive_attr_t &attr) {
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block);
- scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
- }
-}
-
-template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
-template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
-template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
deleted file mode 100644
index d8a05ad53e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
+++ /dev/null
@@ -1,239 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
-#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template<typename Vmm>
-struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
-
- enum { STATE_FIRST_DST_LOAD = 0x1U };
-
- _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr) : jcp(ajcp), attr_(attr),
- eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise);
-
- generate();
- jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
- }
-
- ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
- delete eltwise_injector_;
- }
-
- jit_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker_)(jit_conv_call_s *);
-
-private:
- jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
-
- enum {
- typesize = sizeof(float),
- ker_reg_base_idx = 28,
- ker_dw_reg_base_idx = 30,
- };
- typedef enum {
- no_last_block,
- last_ic_block,
- last_sp_block,
- } ic_block_t;
-
- /* data regs */
- const Xbyak::Reg64 reg_ptr_scales = rax;
- const Xbyak::Reg64 reg_inp = r8;
- const Xbyak::Reg64 reg_ker = r9;
- const Xbyak::Reg64 reg_out = r10;
- const Xbyak::Reg64 aux_reg_inp = r11;
- const Xbyak::Reg64 reg_ptr_sum_scale = r11;
- const Xbyak::Reg64 aux_reg_ker = r12;
- const Xbyak::Reg64 reg_compensation = r14;
- /* counter regs */
- const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
- const Xbyak::Reg64 reg_oi = rbx;
- const Xbyak::Reg64 reg_bias = rdx;
- const Xbyak::Reg64 reg_oc_blocks = rsi;
- const Xbyak::Reg64 reg_owb = aux_reg_ker;
- const Xbyak::Reg64 reg_scratch = reg_compensation;
- const Xbyak::Reg64 reg_kj = reg_ptr_scales;
- const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
- const Xbyak::Reg64 reg_icb = reg_bias;
-
- const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
- const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
-
- const Vmm vmm_wei = Vmm(31);
- /* used during bias section of store_output */
- const Vmm vmm_comp = Vmm(30); // only for signed input
- const Vmm vmm_bias = Vmm(31);
- /* used during post_op sum section of store_output */
- const Vmm vmm_prev_dst = Vmm(31);
- /* used during write-out section of store_output */
- const Vmm vmm_zero = Vmm(31);
-
- /* used in compute_ker (but set during prepare_output) */
- const Vmm vmm_shift = vmm_comp; // only for signed input
- /* used in compute_ker (but only for pre-VNNI machines) */
- const Vmm vmm_tmp = Vmm(28); // not used for depthwise
- const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
-
- /* registers use only for depthwise
- groups are always blocked by 16(padded if needed),
- hence use only Zmm registers */
- const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
- Xbyak::Zmm zmm_tmp;
- Xbyak::Zmm zmm_src;
- Xbyak::Zmm zmm_shifted_zero;
- Xbyak::Zmm zmm_permute;
-
- Vmm vmm_out(int i_ur, int i_oc) {
- int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < (jcp.is_depthwise
- ? ker_dw_reg_base_idx : ker_reg_base_idx));
- return Vmm(idx);
- }
- Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
- int idx = i_ur + i_oc * jcp.ur_w;
- assert(idx < (jcp.is_depthwise
- ? ker_dw_reg_base_idx : ker_reg_base_idx));
- return Xbyak::Zmm(idx);
- }
- Vmm vmm_inp(int i_ic, int nb_x_blocking) {
- int idx = i_ic + nb_x_blocking * jcp.ur_w;
- assert(idx < 31);
- return Vmm(idx);
- }
- Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
- int idx = i_ic + nb_x_blocking * jcp.ur_w;
- assert(idx < 31);
- return Xbyak::Zmm(idx);
- }
- Vmm vmm_bias_alpha() {
- int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- return Vmm(nb_c_block * jcp.ur_w);
- }
- Xbyak::Xmm xmm_bias_alpha() {
- int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- return Xbyak::Xmm(nb_c_block * jcp.ur_w);
- }
- int get_ow_start(int ki, int pad_l) {
- return nstl::max(0,
- utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
- }
- int get_ow_end(int ur_w, int ki, int pad_r) {
- return ur_w - nstl::max(0, utils::div_up(pad_r
- - (jcp.kw - 1 - ki)
- * (jcp.dilate_w + 1),
- jcp.stride_w));
- }
-
- bool maybe_eltwise(int position);
- void prepare_output(int ur_w);
- void store_output(int ur_w, bool last_oc_block_flag);
- void compute_ker_dw(
- int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
- void compute_ker(int ur_w, int pad_l, int pad_r,
- ic_block_t last_ic_block_flag, bool h_padded = false);
- void compute_eltwise(int ur_w);
- void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
- void icb_loop(
- int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
- void generate();
- void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
- bool mask_flag);
- const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
-};
-
-struct jit_avx512_core_x8s8s32x_fwd_kernel {
-
- jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr) :
- jit_ker(nullptr),
- zmm_kernel_(nullptr),
- ymm_kernel_(nullptr),
- xmm_kernel_(nullptr) {
- int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
- switch (ch_block) {
- case 16:
- zmm_kernel_ =
- new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
- ajcp, attr);
- jit_ker = zmm_kernel_->jit_ker_;
- return;
- case 8:
- ymm_kernel_ =
- new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
- ajcp, attr);
- jit_ker = ymm_kernel_->jit_ker_;
- return;
- case 4:
- xmm_kernel_ =
- new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
- ajcp, attr);
- jit_ker = xmm_kernel_->jit_ker_;
- return;
- default:
- assert(!"invalid channel blocking");
- }
- }
-
- ~jit_avx512_core_x8s8s32x_fwd_kernel() {
- delete xmm_kernel_;
- delete ymm_kernel_;
- delete zmm_kernel_;
- }
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- memory_desc_t &src_pd,
- memory_desc_t &weights_pd,
- memory_desc_t &dst_pd,
- memory_desc_t &bias_pd,
- const primitive_attr_t &attr,
- int nthreads);
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
-
- void (*jit_ker)(jit_conv_call_s *);
- _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
- _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
- _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp
deleted file mode 100644
index cdbf333d5e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp
+++ /dev/null
@@ -1,423 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_avx512_core_x8s8s32x_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-using namespace nstl;
-
-using jit_conv_ker_t = void (*)(jit_conv_call_s *);
-
-#define wht_blk_off(d, g, ...) \
- (pd()->with_groups() \
- ? (d).blk_off((g), __VA_ARGS__) \
- : (d).blk_off(__VA_ARGS__))
-
-template <data_type_t src_type, data_type_t dst_type>
-void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
- dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const size_t bia_dt_size = pd()->with_bias()
- ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
- assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
-
- const float *oscales = pd()->attr()->output_scales_.scales_;
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- auto local_scales = scratchpad(ctx).template get<float>(
- key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, oscales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = oscales[c] * factor;
- }
- oscales = local_scales;
- }
-
- size_t offset = weights_d.size() - weights_d.additional_buffer_size();
- auto w = const_cast<wei_data_t *>(weights);
- int32_t* compensation = (jcp.signed_input)
- ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
- int group_block = jcp.ch_block;
- int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
-
- parallel(0, [&](const int ithr, const int nthr) {
-
- int start{ 0 }, end{ 0 };
- balance211(work_amount, nthr, ithr, start, end);
-
- auto p = jit_conv_call_s();
-
- int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 };
- switch (jcp.loop_order) {
- case loop_cwgn:
- nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
- nb_groups, n, jcp.mb);
- break;
- case loop_gncw:
- nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
- owb, jcp.nb_ow);
- break;
- case loop_ngcw:
- nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
- owb, jcp.nb_ow);
- break;
- case loop_nwcg:
- nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
- gg, nb_groups);
- break;
- default: assert(!"unsupported loop order");
- }
- while (start < end) {
- int ocb = occ * jcp.nb_oc_blocking;
- int gb = gg * jcp.nb_ch_blocking;
- int g = gb * group_block;
- int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
- int g_ic = g * jcp.nb_ic * jcp.ic_block;
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
-
- p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0;
- p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
- p.dst = dst + dst_d.blk_off(n, g_oc, ow_s);
- p.src = src + src_d.blk_off(n, g_ic, iw_s);
- p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
- p.scales = &oscales[jcp.is_oc_scale * g_oc];
- p.oc_blocks = jcp.is_depthwise ? gb : ocb;
- p.kh_padding = jcp.kh;
- p.t_overflow = 0;
- p.b_overflow = 0;
- p.owb = owb;
-
- kernel_->jit_ker(&p);
-
- ++start;
- switch (jcp.loop_order) {
- case loop_cwgn:
- nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups,
- n, jcp.mb);
- break;
- case loop_gncw:
- nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb,
- jcp.nb_ow);
- break;
- case loop_ngcw:
- nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb,
- jcp.nb_ow);
- break;
- case loop_nwcg:
- nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg,
- nb_groups);
- break;
- default: assert(!"unsupported loop order");
- }
- }
- });
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
- dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const size_t bia_dt_size = pd()->with_bias()
- ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.ch_block == 1);
- assert(jcp.nb_ch_blocking == 1);
- assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
- assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
-
- const float *oscales = pd()->attr()->output_scales_.scales_;
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- auto local_scales = scratchpad(ctx).template get<float>(
- key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, oscales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = oscales[c] * factor;
- }
- oscales = local_scales;
- }
-
- size_t offset = weights_d.size() - weights_d.additional_buffer_size();
- auto w = const_cast<wei_data_t *>(weights);
- int32_t* compensation = (jcp.signed_input)
- ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
- int nb_groups = jcp.nb_ch;
- int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
-
- parallel(0, [&](const int ithr, const int nthr) {
-
- int start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- auto p = jit_conv_call_s();
-
- size_t src_h_stride = src_d.blk_off(0, 0, 1);
- size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
-
- int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
- switch (jcp.loop_order) {
- case loop_cwgn:
- nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
- nb_groups, n, jcp.mb, oh_s, jcp.oh);
- break;
- case loop_ngcw:
- nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
- owb, jcp.nb_ow, oh_s, jcp.oh);
- break;
- case loop_nhwcg:
- nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
- occ, oc_chunks, g, nb_groups);
- break;
- default: assert(!"unsupported loop order");
- }
- while (start < end) {
- for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
- occ1 += jcp.nb_oc_blocking) {
- int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
- int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
-
- int g_ic = g * jcp.nb_ic * jcp.ic_block;
-
- int work_rem = end - start;
- int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
- int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
- if (jcp.loop_order == loop_nhwcg)
-                    oh_e = oh_s + 1; // step one output row at a time for this loop order
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
-
- auto bias_w = bias
- ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
- : 0;
- int32_t *compensation_w = (jcp.signed_input)
- ? compensation + g_oc : 0;
-
- auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
- auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
-
- auto scales = &oscales[jcp.is_oc_scale * g_oc];
-
- for (int oj = oh_s, ij = ih_s; oj < oh_e;
- ++oj, ij += jcp.stride_h) {
- int dilate_h = jcp.dilate_h + 1;
- int i_t_overflow = nstl::min(jcp.kh,
- div_up(max(0, -ij), dilate_h));
- int i_b_overflow = nstl::min(jcp.kh, div_up(
- max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
- dilate_h));
- int kh_padding = nstl::max(0,
- jcp.kh - i_t_overflow - i_b_overflow);
-
- size_t wei_stride = (!jcp.signed_input)
- ? i_t_overflow * wht_h_stride : 0;
- p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
- p.dst = dst_w;
- p.filt = wht_w + wei_stride;
- p.bias = bias_w;
- p.compensation = compensation_w;
- p.oc_blocks = ocb;
- p.kh_padding = kh_padding;
- p.scales = scales;
- p.t_overflow = i_t_overflow;
- p.b_overflow = i_b_overflow;
- p.owb = owb;
-
- kernel_->jit_ker(&p);
- src_w += src_h_stride * jcp.stride_h;
- dst_w += dst_h_stride;
- }
- }
- switch (jcp.loop_order) {
- case loop_cwgn:
- nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g,
- nb_groups, n, jcp.mb, oh_s, jcp.oh);
- break;
- case loop_ngcw:
- nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
- oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
- break;
- case loop_nhwcg:
- ++start;
- nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
- oc_chunks, g, nb_groups);
- break;
- default: assert(!"unsupported loop order");
- }
- }
- });
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
- dst_type>::execute_forward_2d_dw(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const size_t bia_dt_size = pd()->with_bias()
- ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
-
- const auto &jcp = pd()->jcp_;
- assert(jcp.ic_block == 1);
- assert(jcp.oc_block == 1);
- assert(jcp.nb_ic == 1);
- assert(jcp.nb_oc == 1);
- assert(jcp.nb_oc_blocking == 1);
- assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
-
- const float *oscales = pd()->attr()->output_scales_.scales_;
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- auto local_scales = scratchpad(ctx).template get<float>(
- key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, oscales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = oscales[c] * factor;
- }
- oscales = local_scales;
- }
-
- size_t offset = weights_d.size() - weights_d.additional_buffer_size();
- auto w = const_cast<wei_data_t *>(weights);
- int32_t* compensation = (jcp.signed_input)
- ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
- int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
- int group_block = jcp.ch_block;
-
- parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
- [&](int n, int oh_s, int owb, int gg) {
-
- auto p = jit_conv_call_s();
-
- size_t src_h_stride = src_d.blk_off(0, 0, 1);
- size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
-
- int gb = gg * jcp.nb_ch_blocking;
- int g = gb * group_block;
-
- int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
- int ow_s = owb * jcp.ow_block;
- int iw_s = ow_s * jcp.stride_w;
-
- auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
- int32_t *compensation_w = jcp.signed_input ? compensation + g : 0;
-
- auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s);
- auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
- auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
-
- auto scales = &oscales[jcp.is_oc_scale * g];
-
- int dilate_h = jcp.dilate_h + 1;
- int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
- int i_b_overflow = nstl::min(jcp.kh,
- div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
- dilate_h));
- int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
-
- size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride;
- p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
- p.dst = dst_w;
- p.filt = wht_w + wei_stride;
- p.bias = bias_w;
- p.compensation = compensation_w;
- p.oc_blocks = gb;
- p.kh_padding = kh_padding;
- p.scales = scales;
- p.t_overflow = i_t_overflow;
- p.b_overflow = i_b_overflow;
- p.owb = owb;
-
- kernel_->jit_ker(&p);
- });
-}
-
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::s8, data_type::u8>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::u8, data_type::u8>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::s8, data_type::s8>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::u8, data_type::s8>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::s8, data_type::s32>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::u8, data_type::s32>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::s8, data_type::f32>;
-template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
- data_type::u8, data_type::f32>;
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp
deleted file mode 100644
index 203ebdf942..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp
+++ /dev/null
@@ -1,115 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP
-#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""),
- jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, data_type::s8, data_type::undef,
- dst_type, data_type::s32)
- && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type,
- data_type::f32, data_type::s32, data_type::s8,
- data_type::u8))
- && !has_zero_dim_memory();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(
- jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_,
- *attr(), mkldnn_get_max_threads());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad,
- jcp_, *attr());
-
- return status;
- }
-
- jit_conv_conf_t jcp_;
- };
-
- jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- {
- kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_,
- *pd()->attr());
- }
-
- ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override
- {
- const auto &_pd = pd();
- if (_pd->ndims() == 3)
- execute_forward_1d(ctx);
- else if (_pd->jcp_.is_depthwise)
- execute_forward_2d_dw(ctx);
- else
- execute_forward_2d(ctx);
- return status::success;
- }
-
-private:
- void execute_forward_1d(const exec_ctx_t &ctx) const;
- void execute_forward_2d(const exec_ctx_t &ctx) const;
- void execute_forward_2d_dw(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_avx512_core_x8s8s32x_fwd_kernel *kernel_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp
deleted file mode 100644
index 142af1f541..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp
+++ /dev/null
@@ -1,1034 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
-
-#define GET_OFF(field) offsetof(jit_deconv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-using namespace Xbyak;
-
-using namespace nstl;
-
-#define wht_blk_off(d, g, ...) \
- (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
- (d).blk_off(__VA_ARGS__))
-
-status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
- jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
- memory_desc_t &src_md, memory_desc_t &weights_md,
- memory_desc_t &dst_md, const bool with_bias,
- memory_desc_t &bias_md, const primitive_attr_t &attr) {
- const memory_desc_wrapper src_d(&src_md);
- const memory_desc_wrapper dst_d(&dst_md);
- const memory_desc_wrapper weights_d(&weights_md);
- const memory_desc_wrapper bias_d(&bias_md);
-
- if (!(mayiuse(avx512_core)
- && one_of(src_d.data_type(), data_type::u8, data_type::s8)
- && weights_d.data_type() == data_type::s8
- && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
- data_type::s8, data_type::u8)))
- return status::unimplemented;
-
- jcp = zero<decltype(jcp)>();
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- jcp.signed_input = src_d.data_type() == data_type::s8;
- const int ndims = jcp.ndims = dst_d.ndims();
- const bool is_1d = ndims == 3;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
- jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
- jcp.is_depthwise = true && with_groups
- && utils::everyone_is(1, jcp.ic_without_padding,
- jcp.oc_without_padding);
-
-    /* TODO: future work, on hold until a specialized depthwise kernel is
-     * implemented. */
- if (jcp.is_depthwise && jcp.signed_input)
- return status::unimplemented;
-
- format_tag_t dat_tag = utils::pick(ndims - 3,
- format_tag::nwc, format_tag::nhwc);
-
- if (src_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(src_md, dat_tag));
- jcp.src_tag = dat_tag;
- } else {
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- }
- if (jcp.src_tag != dat_tag)
- return status::unimplemented;
-
- if (dst_d.format_kind() == format_kind::any) {
- CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
- jcp.dst_tag = dat_tag;
- } else {
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
- }
- if (jcp.dst_tag != dat_tag)
- return status::unimplemented;
-
- auto set_or_check_wei_format = [&]() {
- using namespace format_tag;
-
- format_tag_t wei_tag = is_1d
- ? (jcp.is_depthwise
- ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i))
- : (jcp.is_depthwise
- ? Goihw16g : (with_groups ? gOIhw4i16o4i : OIhw4i16o4i));
-
- memory_desc_t want_wei_md = weights_md;
- memory_desc_init_by_tag(want_wei_md, wei_tag);
- if (jcp.signed_input && !jcp.is_depthwise) {
- want_wei_md.extra.flags = 0
- | memory_extra_flags::compensation_conv_s8s8
- | memory_extra_flags::scale_adjust;
- want_wei_md.extra.compensation_mask = (1 << 0)
- + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
- want_wei_md.extra.scale_adjust =
- mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
- }
-
- if (weights_md.format_kind == format_kind::any) {
- weights_md = want_wei_md;
- return true;
- }
-
- return weights_md == want_wei_md;
- };
-
- if (!set_or_check_wei_format())
- return status::unimplemented;
-
- jcp.with_bias = with_bias;
- if (jcp.with_bias) {
- if (bias_d.format_kind() == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
- }
-
- jcp.prop_kind = cd.prop_kind;
- jcp.mb = src_d.dims()[0];
- jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
- jcp.ow = dst_d.dims()[ndims - 1];
- jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
- jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
- jcp.l_pad = cd.padding[0][ndims - 3];
- jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
- jcp.stride_w = cd.strides[ndims - 3];
-
- if (jcp.is_depthwise) {
- jcp.ch_block = 16;
- jcp.oc_block = 1;
- jcp.ic_block = 1;
- } else {
- jcp.ch_block = 1;
- jcp.oc_block = 16;
- jcp.ic_block = 16;
-
- if (jcp.ngroups == 1) {
- jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
- jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
- }
- if (jcp.ic % jcp.ic_block != 0)
- return status::unimplemented;
- }
-
- jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
- jcp.dilate_w = cd.dilates[ndims - 3];
-
- if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
- || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
- return status::unimplemented;
-
- /* padding: bottom and right */
- jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.oh + jcp.t_pad - 1);
- jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.ow + jcp.l_pad - 1);
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- jcp.ver = ver_avx512_core;
- if (mayiuse(avx512_core_vnni))
- jcp.ver = ver_vnni;
- const auto &oscales = attr.output_scales_;
- jcp.is_oc_scale = oscales.mask_ == 1 << 1;
-
- assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
-
- jcp.dst_dt = dst_d.data_type();
- jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
- jcp.typesize_bia
- = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
- jcp.typesize_in = types::data_type_size(src_d.data_type());
- jcp.typesize_out = types::data_type_size(dst_d.data_type());
-
- jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
- jcp.nb_oc = jcp.oc / jcp.oc_block;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- /* kernel blocking params */
- const int regs = jcp.ver == ver_vnni ? 30 : 28;
- jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
- for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
- if (jcp.nb_oc % jcp.nb_oc_blocking == 0
- && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
- break;
-
- jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
- int l_overflow = max(
- 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
-
- if (jcp.ow < jcp.ur_w) {
- jcp.ur_w = jcp.ow;
- jcp.ur_w_tail = 0;
- } else {
- for (; jcp.ur_w >= 1; jcp.ur_w--) {
-            /* ur_w should be a multiple of stride_w in order
-               to simplify the logic for get_ow_start and get_ow_end */
- bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
-
-            /* boundary conditions:
-               These conditions ensure all elements close to the boundary
-               are computed in a single call of the compute loop */
- bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- int r_overflow_no_tail
- = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
- - max(0, jcp.r_pad) - jcp.ur_w_tail)
- / jcp.stride_w);
- bool right_boundary_covered
- = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
-
- if (is_multiple_of_stride && left_boundary_covered
- && right_boundary_covered)
- break;
- else if (jcp.ur_w == 1)
-                /* The boundary conditions above are also important
-                   for keeping the calls to icb_loop simple.
-                   If those conditions are not satisfied,
-                   special cases would need to be added
-                   to use the correct l_overflow/r_overflow values
-                   when different iterations of the compute loop
-                   work on locations close to the boundary.
-                   So, to keep the code simple, return unimplemented
-                   for the extreme case where a good ur_w cannot be found.
-                 */
- return status::unimplemented;
- }
- }
-
- jcp.wei_adj_scale =
- (weights_d.extra().flags | memory_extra_flags::scale_adjust)
- ? weights_d.extra().scale_adjust : 1.f;
-
- jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
- return status::success;
-}
-
-bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
- using namespace primitive_kind;
- const auto &p = attr_.post_ops_;
-
- if (position == 0) {
- /* eltwise before sum */
- return p.contain(eltwise, 0);
- } else if (position == 1) {
- /* eltwise after sum */
- return p.contain(sum, 0) && p.contain(eltwise, 1);
- }
- return false;
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
- int nb_oc_block
- = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
- eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w);
-}
-
-bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- using namespace primitive_kind;
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
-
- switch (p.len_) {
- case 0: return true;
- case 1: return is_eltwise(0) || p.contain(sum, 0);
- case 2:
- return (p.contain(sum, 0) && is_eltwise(1))
- || (p.contain(sum, 1) && is_eltwise(0));
- default: return false;
- }
-
- return false;
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
- const primitive_attr_t &attr) {
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16);
- scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
- int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
- bool h_padded) {
-
- const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
- const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
-
- auto src_offset = [=](int oj, int icb, int ki) {
- return jcp.typesize_in
- * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
- * jcp.ngroups * jcp.ic_without_padding
- + icb * 4);
- };
-
- auto kernel_offset = [=](int ocb, int icb, int ki) {
- return jcp.typesize_in
- * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
- + icb * jcp.oc_block * jcp.ic_block / 4
- + ki * ch_block_all);
- };
-
- auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
- if (jcp.ver == ver_vnni) {
- vpdpbusd(vreg_acc, vreg_src, vreg_wei);
- } else if (jcp.is_depthwise) {
- vpmulld(zmm_tmp, vreg_src, vreg_wei);
- vpaddd(vreg_acc, vreg_acc, zmm_tmp);
- } else {
- vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
- vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
- vpaddd(vreg_acc, vreg_acc, zmm_tmp);
- }
- };
-
- for (int ki = 0; ki < jcp.kw; ki++) {
-
- int jj_start = get_ow_start(ki, l_overflow);
- int jj_end = get_ow_end(ur_w, ki, r_overflow);
-
- int _start = (jcp.signed_input) ? 0 : jj_start;
- int _end = (jcp.signed_input) ? ur_w : jj_end;
-
- int tail_size = jcp.ic_without_padding % 4;
- int n_ic_blocks = jcp.is_depthwise ?
- 1 :
- (last_ic_block_flag & ~no_last_block ?
- div_up(jcp.ic_without_padding % jcp.ic_block,
- 4) :
- jcp.ic_block / 4);
-
- for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
- if (h_padded == true) {
- /* fill padded area with shifted values */
- Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
- vpxord(inp, inp, inp);
- vpsubb(inp, inp, zmm_shift);
- } else {
-
- for (int jj = _start; jj < _end; jj += ur_w_stride) {
-
- int aux_src_off = src_offset(jj, icb1, ki);
-
- if (jj >= jj_start && jj < jj_end
- && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
- if (jcp.is_depthwise) {
- vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
- EVEX_compress_addr(
- aux_reg_src, aux_src_off));
- } else if ((last_ic_block_flag & last_sp_block)
- && tail_size != 0 && icb1 == n_ic_blocks - 1) {
- xmm_t xmm_tmp = xmm_t(
- zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
- for (int r = 0; r < tail_size; ++r)
- vpinsrb(xmm_tmp, xmm_tmp,
- ptr[aux_reg_src + aux_src_off + r], r);
- vpbroadcastd(
- zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
- } else {
- vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
- EVEX_compress_addr(
- aux_reg_src, aux_src_off));
- }
- if (jcp.signed_input)
- vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
- zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
- } else {
- /* fill padded area with shifted values */
- if (jcp.signed_input) {
- Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
- vpxord(inp, inp, inp);
- vpsubb(inp, inp, zmm_shift);
- }
- }
- }
- }
- for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
- int aux_filt_off = kernel_offset(ocb, icb1, ki);
-
- if (_end - _start > 0) {
- if (jcp.is_depthwise)
- vpmovsxbd(zmm_wei,
- EVEX_compress_addr(aux_reg_filt, aux_filt_off));
- else
- vmovups(zmm_wei,
- EVEX_compress_addr(aux_reg_filt, aux_filt_off));
- }
- for (int jj = _start; jj < _end; jj += ur_w_stride) {
- Zmm inp = (h_padded == true) ?
- zmm_inp(0, jcp.nb_oc_blocking) :
- zmm_inp(jj, jcp.nb_oc_blocking);
- compute(zmm_out(jj, ocb), zmm_wei, inp);
- }
- }
- }
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
- int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
-
- int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
- int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
- * jcp.ngroups * jcp.ic_without_padding;
- const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
- int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
-
- Label kh_loop_label, skip_kh_loop;
- Label t_overflow_label, no_t_overflow_label, b_overflow_label,
- no_b_overflow_label;
-
- mov(aux_reg_src, reg_src);
- mov(aux_reg_filt, reg_filt);
-
- if (jcp.signed_input && jcp.ndims > 3) {
- /* Weights are transposed, so first compute 'bottom' padding. */
- mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
- cmp(reg_overflow, 0);
- je(no_b_overflow_label, T_NEAR);
- L(b_overflow_label); {
- compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
-
- add(aux_reg_filt, shift_filt_kh);
- dec(reg_overflow);
- cmp(reg_overflow, 0);
- jg(b_overflow_label, T_NEAR);
- }
- L(no_b_overflow_label);
- }
-
- mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
-
- if (jcp.signed_input || ((!jcp.signed_input)
- && ((min(jcp.t_pad, jcp.b_pad) < 0)
- || ((jcp.kh - 1) * (jcp.dilate_h + 1)
- < nstl::max(jcp.t_pad, jcp.b_pad))))) {
- cmp(reg_kh, 0);
- je(skip_kh_loop, T_NEAR);
- }
-
- L(kh_loop_label); {
- compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
- sub(aux_reg_src, shift_src_ih);
- add(aux_reg_filt, shift_filt_kh);
- dec(reg_kh);
-
- /* Insert weight compensation in stride 'holes' */
- if (jcp.signed_input && jcp.stride_h > 1) {
- Label kh_comp_loop;
-
- cmp(reg_kh, 0);
- je(skip_kh_loop, T_NEAR);
- mov(reg_comp_strides, jcp.stride_h - 1);
- L(kh_comp_loop);
- {
- compute_ker(
- ur_w, 0, 0, last_ic_block_flag, true);
- add(aux_reg_filt, shift_filt_kh);
- dec(reg_comp_strides);
- cmp(reg_comp_strides, 0);
- jg(kh_comp_loop, T_NEAR);
- }
- }
- cmp(reg_kh, 0);
- jg(kh_loop_label, T_NEAR);
- }
- L(skip_kh_loop);
- if (jcp.signed_input && jcp.ndims > 3) {
- mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
- cmp(reg_overflow, 0);
- je(no_t_overflow_label, T_NEAR);
- L(t_overflow_label); {
- compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
-
- add(aux_reg_filt, shift_filt_kh);
- dec(reg_overflow);
- cmp(reg_overflow, 0);
- jg(t_overflow_label, T_NEAR);
- }
- L(no_t_overflow_label);
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
- for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
- for (int ur = 0; ur < ur_w; ur++) {
- zmm_t zmm = zmm_out(ur, ocb);
- vpxord(zmm, zmm, zmm);
- }
- }
- if (jcp.signed_input) {
- xor_(reg_scratch, reg_scratch);
- Reg8 _t8 = reg_scratch.cvt8();
- mov(_t8, (int8_t)-128);
- vpbroadcastb(zmm_shift, _t8);
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
- data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
- zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
- switch (type_in) {
- case data_type::f32:
- case data_type::s32: vmovups(zmm, op); break;
- case data_type::s8: vpmovsxbd(zmm, op); break;
- case data_type::u8: vpmovzxbd(zmm, op); break;
- default: assert(!"unsupported data type");
- }
- if (type_in != data_type::f32)
- vcvtdq2ps(zmm_in, zmm_in);
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
- int ur_w, bool last_oc_block) {
- mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
- mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
-
- if (jcp.signed_input)
- mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
-
- const auto &p = attr_.post_ops_;
- const int sum_idx = p.find(primitive_kind::sum);
- const float *p_sum_scale
- = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
- if (p_sum_scale && *p_sum_scale != 1.f)
- mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
-
- if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
- mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
- vmovq(xmm_bias_alpha(), reg_bias_alpha);
- vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
- }
-
- for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
- const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
- int scale_offset
- = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
-
- auto zmm_bias = zmm_tmp;
- if (jcp.with_bias) {
- int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
- auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
- cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
- if (jcp.signed_input && jcp.ver != ver_vnni)
- vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
- }
- if (jcp.signed_input) {
- int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
- auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
- cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
- }
-
- for (int ur = 0; ur < ur_w; ur++) {
- zmm_t zmm = zmm_out(ur, ocb);
- vcvtdq2ps(zmm, zmm);
- if (jcp.signed_input)
- vaddps(zmm, zmm, zmm_comp);
- if (jcp.with_bias)
- vaddps(zmm, zmm, zmm_bias);
- zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
- vmulps(mask_zmm, zmm,
- EVEX_compress_addr(reg_ptr_scales, scale_offset));
- }
- }
- if (maybe_eltwise(0))
- compute_eltwise(ur_w);
- if (p_sum_scale) { // post_op: sum
- for (int k = 0; k < jcp.nb_oc_blocking; k++) {
- const bool mask_flag
- = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
- for (int j = 0; j < ur_w; j++) {
- int aux_output_offset
- = jcp.typesize_out
- * (k * jcp.oc_block
- + j * jcp.oc_without_padding * jcp.ngroups);
- auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
- Zmm zmm = zmm_out(j, k);
- cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
- if (*p_sum_scale == 1.f)
- vaddps(zmm, zmm_prev_dst);
- else
- vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
- }
- }
- }
- if (maybe_eltwise(1))
- compute_eltwise(ur_w);
-
- for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
- const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
- for (int ur = 0; ur < ur_w; ur++) {
- zmm_t zmm = zmm_out(ur, ocb);
- if (jcp.dst_dt == data_type::u8) {
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- vmaxps(zmm, zmm_zero, zmm);
- }
- if (jcp.dst_dt != data_type::f32)
- vcvtps2dq(zmm, zmm);
- }
- for (int ur = 0; ur < ur_w; ur++) {
- int aux_dst_off = jcp.typesize_out
- * (ur * jcp.ngroups * jcp.oc_without_padding
- + ocb * jcp.oc_block);
- auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
-
- zmm_t zmm = zmm_out(ur, ocb);
- zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
- switch (jcp.dst_dt) {
- case data_type::f32:
- case data_type::s32: vmovups(addr, r_zmm); break;
- case data_type::s8: vpmovsdb(addr, r_zmm); break;
- case data_type::u8: vpmovusdb(addr, r_zmm); break;
- default: assert(!"unknown dst_dt");
- }
- }
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
- int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
-
- int shift_src_icb = jcp.typesize_in * jcp.ic_block;
- int shift_filt_icb
- = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
-
- prepare_output(ur_w);
-
- Label skip_icb_loop, icb_loop_label;
-
- mov(reg_icb, jcp.nb_ic);
- L(icb_loop_label); {
-
- if (jcp.ic_without_padding != jcp.ic) {
- Label common_ker, end_ker;
- cmp(reg_icb, 1);
- jg(common_ker, T_NEAR);
-
- kh_loop(ur_w, l_overflow, r_overflow,
- is_last_sp_block ? last_sp_block : last_ic_block);
- jmp(end_ker, T_NEAR);
-
- L(common_ker);
- kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
-
- L(end_ker);
- } else {
- kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
- }
-
- add(reg_src, shift_src_icb);
- add(reg_filt, shift_filt_icb);
- dec(reg_icb);
- cmp(reg_icb, 0);
- jg(icb_loop_label, T_NEAR);
- }
-
-    /* rewind pointers back to their initial values */
- sub(reg_src, jcp.nb_ic * shift_src_icb);
- sub(reg_filt, jcp.nb_ic * shift_filt_icb);
- L(skip_icb_loop);
-
- if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
- Label common_store, end_store;
- mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
- if (jcp.is_depthwise)
- cmp(reg_oc_blocks, jcp.nb_ch - 1);
- else
- cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
- jne(common_store, T_NEAR);
-
- store_output(ur_w, true);
- jmp(end_store, T_NEAR);
-
- L(common_store);
- store_output(ur_w, false);
-
- L(end_store);
-
- } else {
- store_output(ur_w, false);
- }
-}
-
-void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
- preamble();
-
- xor_(reg_scratch, reg_scratch);
- Reg16 _t = reg_scratch.cvt16();
- mov(_t, 0x1);
- vpbroadcastw(zmm_one, _t);
-
- if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
- int tail_size = jcp.is_depthwise ?
- jcp.ngroups % jcp.ch_block :
- jcp.oc_without_padding % jcp.oc_block;
- int mask = (1 << tail_size) - 1;
- Reg32 regw_tmp = reg_nur_w.cvt32();
- mov(regw_tmp, mask);
- kmovw(ktail_mask, regw_tmp);
- }
-
- mov(reg_src, ptr[param1 + GET_OFF(src)]);
- mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
- mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
-
- int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
- * jcp.oc_without_padding;
- int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
- * jcp.ic_without_padding;
-
- int l_overflow = max(
- 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
- int r_overflow
- = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
- / jcp.stride_w);
-
- int r_overflow1
- = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
- - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
- / jcp.stride_w);
- int nur_w = jcp.ow / jcp.ur_w;
- if (r_overflow1 > 0)
- nur_w--;
-
- if (jcp.ur_w == jcp.ow) {
- icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
- } else if (nur_w == 0) {
- icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- if (jcp.ur_w_tail != 0)
- icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
- } else {
- xor_(reg_nur_w, reg_nur_w);
- if (l_overflow > 0) {
- icb_loop(jcp.ur_w, l_overflow, 0, false);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- inc(reg_nur_w);
- }
- if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
- Label ow_loop_label;
- L(ow_loop_label);
- {
- icb_loop(jcp.ur_w, 0, 0, false);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- inc(reg_nur_w);
- cmp(reg_nur_w, nur_w);
- jl(ow_loop_label, T_NEAR);
- }
- }
- if (r_overflow1 > 0) {
- icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
- add(reg_src, src_shift);
- add(reg_dst, dst_shift);
- }
- if (jcp.ur_w_tail != 0) {
- icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
- }
- }
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
- dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- auto &jcp = kernel_->jcp;
-
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int nb_groups = jcp.nb_ch;
-
- const float *oscales = pd()->attr()->output_scales_.scales_;
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- auto local_scales
- = scratchpad(ctx).template get<float>(key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, oscales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = oscales[c] * factor;
- }
- oscales = local_scales;
- }
- size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
- auto w = const_cast<wei_data_t *>(weights);
- int32_t *compensation
- = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
-
- parallel(0, [&](const int ithr, const int nthr) {
- int start{ 0 }, end{ 0 };
- int work_amount = jcp.mb * nb_groups * oc_chunks;
- balance211(work_amount, nthr, ithr, start, end);
-
- auto p = jit_deconv_call_s();
-
- int n{ 0 }, g{ 0 }, occ{ 0 };
- if (jcp.loop_order == loop_ngc)
- nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
- else if (jcp.loop_order == loop_cgn)
- nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
- else
- assert(!"unsupported loop order");
- while (start < end) {
-
- int ocb = occ * jcp.nb_oc_blocking;
- int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
- int g_ic = g * jcp.ch_block * jcp.ic;
-
- p.dst = dst + dst_d.blk_off(n, g_oc);
- p.src = src + src_d.blk_off(n, g_ic);
- p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
- p.bias = jcp.with_bias ?
- bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
- 0;
- p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
- p.scales = &oscales[jcp.is_oc_scale * g_oc];
- p.t_overflow = 0;
- p.b_overflow = 0;
- p.kh_padding = jcp.kh;
- p.oc_blocks = jcp.is_depthwise ? g : ocb;
-
- kernel_->jit_ker(&p);
-
- ++start;
- if (jcp.loop_order == loop_ngc)
- nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
- else if (jcp.loop_order == loop_cgn)
- nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
- else
- assert(!"unsupported loop order");
- }
- });
-}
-
-template <data_type_t src_type, data_type_t dst_type>
-void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
- dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- auto &jcp = kernel_->jcp;
-
- int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
- int nb_groups = jcp.nb_ch;
-
- size_t src_h_stride = src_d.blk_off(0, 0, 1);
- size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
- size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
-
- const float *oscales = pd()->attr()->output_scales_.scales_;
- if (jcp.signed_input && jcp.ver != ver_vnni) {
- auto local_scales
- = scratchpad(ctx).template get<float>(key_conv_adjusted_scales);
- size_t count = pd()->attr()->output_scales_.count_;
- float factor = 1.f / pd()->jcp_.wei_adj_scale;
- if (count == 1) {
- utils::array_set(local_scales, oscales[0] * factor, 16);
- } else {
- for (size_t c = 0; c < count; c++)
- local_scales[c] = oscales[c] * factor;
- }
- oscales = local_scales;
- }
- size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
- auto w = const_cast<wei_data_t *>(weights);
- int32_t *compensation
- = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
-
- parallel(0, [&](const int ithr, const int nthr) {
- int start{ 0 }, end{ 0 };
- int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
- balance211(work_amount, nthr, ithr, start, end);
-
- auto p = jit_deconv_call_s();
-
-        /* loop order: cgn or ngc */
- int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
- if (jcp.loop_order == loop_ngc)
- nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
- oh_s, jcp.oh);
- else if (jcp.loop_order == loop_cgn)
- nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
- oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
- while (start < end) {
-
- int ocb = occ * jcp.nb_oc_blocking;
- int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
- int g_ic = g * jcp.ch_block * jcp.ic;
- int work_rem = end - start;
- int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
-
- auto dst_w = dst + dst_d.blk_off(n, g_oc);
- auto src_w = src + src_d.blk_off(n, g_ic);
- auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
- auto bias_w = jcp.with_bias ?
- bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
- 0;
- int32_t *compensation_w
- = (jcp.signed_input) ? compensation + g_oc : 0;
-
- auto scales = &oscales[jcp.is_oc_scale * g_oc];
- for (int oj = oh_s; oj < oh_e; oj++) {
- int ih_max = 0, kh_lo = 0, kh_len = 0;
- if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
- /* dilation */
- int dilate_h = jcp.dilate_h + 1;
- // Note: use div_up to account for "holes" in filter
- int o_t_overflow = div_up(
- max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
- dilate_h);
- int o_b_overflow
- = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
- + oj - jcp.b_pad),
- dilate_h);
- kh_len = jcp.kh - o_t_overflow - o_b_overflow;
- kh_lo = o_b_overflow;
- ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
- } else {
- int o_t_overflow = max(
- 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
- int o_b_overflow
- = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
- / jcp.stride_h);
- int overflow_kh_hi = jcp.kh - 1
- - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
- int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
-
- kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
- + 1 - o_t_overflow - o_b_overflow;
- kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
- ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
- }
-
- int wei_stride
- = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
- p.src = src_w + ih_max * src_h_stride;
- p.dst = dst_w + oj * dst_h_stride;
- p.filt = wht_w + wei_stride;
- p.bias = bias_w;
- p.compensation = compensation_w;
- p.t_overflow = max(
- 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
- + 1));
- p.b_overflow = kh_lo;
- p.kh_padding = kh_len;
- p.scales = scales;
- p.oc_blocks = jcp.is_depthwise ? g : ocb;
- kernel_->jit_ker(&p);
- }
- if (jcp.loop_order == loop_ngc)
- nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
- oc_chunks, oh_s, jcp.oh);
- else if (jcp.loop_order == loop_cgn)
- nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
- jcp.mb, oh_s, jcp.oh);
- else
- assert(!"unsupported loop order");
- }
- });
-}
-
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
- data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
- data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
- data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
- data_type::s32>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
- data_type::u8>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
- data_type::s8>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
- data_type::f32>;
-template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
- data_type::s32>;
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp
deleted file mode 100644
index 901038fa48..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp
+++ /dev/null
@@ -1,237 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP
-#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_memory.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "nstl.hpp"
-
-#include "cpu_deconvolution_pd.hpp"
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-typedef enum {
- no_last_block = 0x1U,
- last_ic_block = 0x2U,
- last_sp_block = 0x4U,
-} ker_block_t;
-
-struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t);
-
- jit_avx512_core_x8s8s32x_deconv_fwd_kernel(
- const jit_conv_conf_t &ajcp, const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
- this, jcp.eltwise);
- generate();
- jit_ker = (void (*)(jit_deconv_call_s *))getCode();
- }
-
- ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const deconvolution_desc_t &cd,
- memory_desc_t &src_md,
- memory_desc_t &weights_md,
- memory_desc_t &dst_md,
- const bool with_bias,
- memory_desc_t &bias_md,
- const primitive_attr_t &attr);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
-
- const jit_conv_conf_t &jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_deconv_call_s *);
-private:
- jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
- using reg64_t = const Xbyak::Reg64;
- using zmm_t = const Xbyak::Zmm;
- using xmm_t = const Xbyak::Xmm;
-
- reg64_t reg_src = r8;
- reg64_t reg_filt = r9;
- reg64_t reg_dst = r10;
- reg64_t param1 = abi_param1;
- reg64_t reg_kh = abi_not_param1;
- reg64_t reg_nur_w = rbx;
- reg64_t reg_bias = rdx;
- reg64_t reg_icb = reg_bias;
- reg64_t reg_ptr_scales = rax;
- reg64_t reg_oc_blocks = rsi;
-
- reg64_t aux_reg_src = r11;
- reg64_t aux_reg_filt = r12;
-
- reg64_t reg_compensation = r14;
- reg64_t reg_scratch = r14;
- reg64_t reg_ptr_sum_scale = r11;
- reg64_t reg_bias_alpha = abi_not_param1;
- reg64_t reg_overflow = rax;
- reg64_t reg_comp_strides = reg_overflow;
-
- Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
- zmm_t zmm_tmp = zmm_t(28);
- zmm_t zmm_one = zmm_t(29);
- /* used during write-out section of store_output */
- zmm_t zmm_zero = zmm_t(31);
- zmm_t zmm_wei = zmm_t(31);
-
- /* signed input */
- zmm_t zmm_shift = zmm_t(30);
- zmm_t zmm_comp = zmm_t(30);
- zmm_t zmm_bias = zmm_t(31);
- zmm_t zmm_prev_dst = zmm_t(31);
-
- zmm_t zmm_out(int i_ur, int i_oc) {
- int idx = i_ur * jcp.nb_oc_blocking + i_oc;
- assert(idx < 31);
- return zmm_t(idx);
- }
- zmm_t zmm_inp(int i_ic, int nb_x_blocking) {
- int idx = i_ic + nb_x_blocking * jcp.ur_w;
- assert(idx < 31);
- return zmm_t(idx);
- }
- zmm_t zmm_bias_alpha() {
- return zmm_t(jcp.nb_oc_blocking * jcp.ur_w);
- }
- xmm_t xmm_bias_alpha() {
- return xmm_t(jcp.nb_oc_blocking * jcp.ur_w);
- }
-
- int get_ow_start(int ki, int l_overflow) {
- int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w
- + l_overflow * jcp.stride_w
- - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
- return res;
- }
-
- int get_ow_end(int ur_w, int ki, int r_overflow) {
- if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail))
- ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
- int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
- + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
- while (res < 0)
- res += jcp.stride_w;
- return ur_w - res;
- }
- bool maybe_eltwise(int position);
- void compute_eltwise(int ur_w);
- void prepare_output(int ur_w);
- void store_output(int ur_w, bool last_oc_block);
- void compute_ker(int ur_w, int l_overflow, int r_overflow,
- ker_block_t last_ic_block_flag, bool h_padded = false);
- void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
- void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);
- void generate();
- void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
- bool mask_flag);
-};
-
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_deconvolution_fwd_pd_t {
- using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""),
- _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type, dst_type>);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && (desc()->alg_kind & alg_kind::deconvolution_direct)
- && desc()->src_desc.data_type == src_type
- && desc()->dst_desc.data_type == dst_type
- && IMPLICATION(with_bias(), utils::one_of(
- desc()->bias_desc.data_type, data_type::f32,
- data_type::s32, data_type::s8, data_type::u8))
- && desc()->accum_data_type == data_type::s32;
- if (!ok) return status::unimplemented;
-
- status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel::
- init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_,
- with_bias(), bias_md_, *attr());
-
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad,
- jcp_, *attr());
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
- };
-
- _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- {
- kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_,
- *pd()->attr());
- }
-
- ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<data_type::s8>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if(pd()->ndims() == 3)
- execute_forward_1d(ctx);
- else
- execute_forward_2d(ctx);
- return status::success;
- }
-
-private:
- void execute_forward_1d(const exec_ctx_t &ctx) const;
- void execute_forward_2d(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
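// Editor's note (illustrative sketch, not part of the deleted file above): the
// ker_block_t values in jit_avx512_core_x8s8s32x_deconvolution.hpp are distinct
// bits, so a tail iteration can be described by OR-ing flags together and tested
// bit by bit inside the kernel. A minimal, hedged model of that pattern
// (enum and function names here are hypothetical):

enum ker_block_flags : unsigned {
    no_last_block = 0x1u,
    last_ic_block = 0x2u,
    last_sp_block = 0x4u,
};

// True when the current iteration is both the last input-channel block and the
// last spatial block, i.e. both tail code paths must be taken at once.
inline bool is_full_tail(unsigned flags) {
    return (flags & last_ic_block) && (flags & last_sp_block);
}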
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp
deleted file mode 100644
index c09592d5c9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp
+++ /dev/null
@@ -1,773 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_AVX2_GENERATOR_HPP
-#define CPU_JIT_AVX2_GENERATOR_HPP
-
-#include <limits.h>
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_isa_traits.hpp"
-#include "jit_utils/jit_utils.hpp"
-
-#if defined(_WIN32) && !defined(__GNUC__)
-# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
-#else
-# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
-#endif
-
-#if defined(_WIN32)
-# define OFFSET_SHADOWSPACE 0x28
-#endif
-
-#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
- const char *name() const override { return STRINGIFY(jit_name); } \
- const char *source_file() const override { return __FILE__; }
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-// TODO: move this to jit_generator class?
-namespace {
-
-typedef enum {
- PAGE_4K = 4096,
- PAGE_2M = 2097152,
-} cpu_page_size_t;
-
-// TODO: move this somewhere else? Although this is only used by jit kernels
-// (Roma)
-static inline int float2int(float x) {
- union {
- float vfloat;
- int vint;
- } cvt;
- cvt.vfloat = x;
- return cvt.vint;
-}
-
-// TODO: A GPR class that hides ABI details from the JIT kernels and allows
-// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
-// stack register (sr).
-//
-// This will allow using syntax like this:
-//
-// param = gpr0;
-// reg_input = gpr0;
-// reg_output = gpr1;
-// ...
-//
-// #ifndef XBYAK64
-// mov(param, ptr[sr])
-// #endif
-//
-// (Roma)
-
-#ifdef XBYAK64
-constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
- Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
- Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
-#ifdef _WIN32
- Xbyak::Operand::RDI, Xbyak::Operand::RSI,
-#endif
-};
-
-#ifdef _WIN32
-static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
- abi_param2(Xbyak::Operand::RDX),
- abi_param3(Xbyak::Operand::R8),
- abi_param4(Xbyak::Operand::R9),
- abi_not_param1(Xbyak::Operand::RDI);
-#else
-static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
- abi_param2(Xbyak::Operand::RSI),
- abi_param3(Xbyak::Operand::RDX),
- abi_param4(Xbyak::Operand::RCX),
- abi_param5(Xbyak::Operand::R8),
- abi_param6(Xbyak::Operand::R9),
- abi_not_param1(Xbyak::Operand::RCX);
-#endif
-#endif
-
-inline unsigned int get_cache_size(int level, bool per_core = true){
- unsigned int l = level - 1;
- // Currently, if XByak is not able to fetch the cache topology
- // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
- if (cpu.getDataCacheLevels() == 0){
- const int L1_cache_per_core = 32000;
- const int L2_cache_per_core = 512000;
- const int L3_cache_per_core = 1024000;
- int num_cores = per_core ? 1 : mkldnn_get_max_threads();
- switch(l){
- case(0): return L1_cache_per_core * num_cores;
- case(1): return L2_cache_per_core * num_cores;
- case(2): return L3_cache_per_core * num_cores;
- default: return 0;
- }
- }
- if (l < cpu.getDataCacheLevels()) {
- return cpu.getDataCacheSize(l)
- / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
- } else
- return 0;
-}
-
-}
-
-class jit_generator : public Xbyak::CodeGenerator
-{
-private:
- const size_t xmm_len = 16;
-#ifdef _WIN32
- const size_t xmm_to_preserve_start = 6;
- const size_t xmm_to_preserve = 10;
-#else
- const size_t xmm_to_preserve_start = 0;
- const size_t xmm_to_preserve = 0;
-#endif
-
- const size_t num_abi_save_gpr_regs
- = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
-
- const size_t size_of_abi_save_regs
- = num_abi_save_gpr_regs * rax.getBit() / 8
- + xmm_to_preserve * xmm_len;
-
-public:
- enum {
- _cmp_eq_oq = 0u,
- _cmp_lt_os = 1u,
- _cmp_le_os = 2u,
- _cmp_neq_uq = 4u,
- _cmp_nlt_us = 5u,
- _cmp_nle_us = 6u,
-
- _op_floor = 1u,
- _op_mxcsr = 4u,
- };
-
- Xbyak::Reg64 param1 = abi_param1;
- const int EVEX_max_8b_offt = 0x200;
- const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
-
- inline size_t get_size_of_abi_save_regs() {
- return size_of_abi_save_regs;
- }
-
- void preamble() {
- if (xmm_to_preserve) {
- sub(rsp, xmm_to_preserve * xmm_len);
- for (size_t i = 0; i < xmm_to_preserve; ++i)
- movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
- }
- for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
- push(Xbyak::Reg64(abi_save_gpr_regs[i]));
- if (mayiuse(avx512_common)) {
- mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
- }
- }
-
- void mic_prefetcht0(Xbyak::Address a) {
- if (mayiuse(avx512_mic))
- prefetcht0(a);
- }
-
- void mic_prefetcht1(Xbyak::Address a) {
- if (mayiuse(avx512_mic))
- prefetcht1(a);
- }
-
- void mic_prefetcht2(Xbyak::Address a) {
- if (mayiuse(avx512_mic))
- prefetcht2(a);
- }
-
- void uni_vzeroupper() {
- if (mayiuse(avx) && !mayiuse(avx512_mic))
- vzeroupper();
- }
-
- void postamble() {
- for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
- pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
- if (xmm_to_preserve) {
- for (size_t i = 0; i < xmm_to_preserve; ++i)
- movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
- add(rsp, xmm_to_preserve * xmm_len);
- }
- uni_vzeroupper();
- ret();
- }
-
- template<typename T>
- Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
- T raw_offt, bool bcast = false)
- {
- using Xbyak::Zmm;
- using Xbyak::Reg64;
- using Xbyak::Address;
- using Xbyak::RegExp;
-
- assert(raw_offt <= INT_MAX);
- auto offt = static_cast<int>(raw_offt);
-
- int scale = 0;
-
- if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
- offt = offt - 2 * EVEX_max_8b_offt;
- scale = 1;
- } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
- offt = offt - 4 * EVEX_max_8b_offt;
- scale = 2;
- }
-
- auto re = RegExp() + base + offt;
- if (scale)
- re = re + reg_EVEX_max_8b_offt * scale;
-
- if (bcast)
- return zword_b [re];
- else
- return zword [re];
- }
-
- Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
- const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
- if (offt > INT_MAX) {
- mov(tmp_reg, offt);
- return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
- } else {
- return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
- }
- }
-
- Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
- size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
- if (raw_offt > INT_MAX) {
- return make_safe_addr(base, raw_offt, reg_offt, bcast);
- } else {
- return EVEX_compress_addr(base, raw_offt, bcast);
- }
- }
-
- void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
- const Xbyak::Reg64 &reg_offt) {
- if (raw_offt > INT_MAX) {
- mov(reg_offt, raw_offt);
- add(base, reg_offt);
- } else {
- add(base, raw_offt);
- }
- }
-
- void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
- const Xbyak::Reg64 &reg_offt) {
- if (raw_offt > INT_MAX) {
- mov(reg_offt, raw_offt);
- sub(base, reg_offt);
- } else {
- sub(base, raw_offt);
- }
- }
-
- // Disallow char-based labels completely
- void L(const char *label) = delete;
- void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
-
- void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- assert(x1.getIdx() == x2.getIdx());
- pxor(x2, op);
- }
- void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- if (mayiuse(avx2)) {
- vpxor(x1, x2, op);
- } else {
- vxorps(x1, x2, op);
- }
- }
- void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
- const Xbyak::Operand &op) {
- vpxord(x1, x2, op);
- }
-
- void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
- movss(addr, x);
- }
- void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
- vmovss(addr, x);
- }
- void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
- movss(x, addr);
- }
- void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
- vmovss(x, addr);
- }
-
- void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
- movsd(addr, x);
- }
- void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
- vmovsd(addr, x);
- }
- void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
- movsd(x, addr);
- }
- void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
- vmovsd(x, addr);
- }
-
- void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
- movdqu(addr, x);
- }
- void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
- vmovdqu(addr, x);
- }
- void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
- vmovdqu32(addr, x);
- }
-
- void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
- movdqu(x, addr);
- }
- void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
- vmovdqu(x, addr);
- }
- void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
- vmovdqu32(x, addr);
- }
-
- void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
- movups(addr, x);
- }
- void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
- vmovups(addr, x);
- }
-
- void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- movups(x, op);
- }
- void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- vmovups(x, op);
- }
-
- void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
- movntps(addr, x);
- }
- void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
- vmovntps(addr, x);
- }
-
- void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- movss(x, op);
- shufps(x, x, 0x0);
- }
- void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- if (op.isMEM() || mayiuse(avx2)) {
- vbroadcastss(x, op);
- } else {
- Xbyak::Xmm t(x.getIdx());
- if (t.getIdx() != op.getIdx()) movss(t, op);
- vinsertf128(x, x, t, 1);
- vshufps(x, x, x, 0);
- }
- }
-
- void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- movsd(x, op);
- pshufd(x, x, 0x0);
- }
- void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- if (mayiuse(avx2)) {
- vpbroadcastd(x, op);
- } else {
- Xbyak::Xmm t(x.getIdx());
- if (t.getIdx() != op.getIdx()) movsd(t, op);
- vinsertf128(x, x, t, 1);
- vshufps(x, x, x, 0);
- }
- }
-
- void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- rcpss(x, op);
- }
- void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
- Xbyak::Xmm x1_(x1.getIdx());
- Xbyak::Xmm x2_(x2.getIdx());
- vrcpss(x1_, x1_, x2_);
- }
- void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
- Xbyak::Xmm x_(x.getIdx());
- vrcpss(x_, x_, op);
- }
-
- void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- rcpps(x, op);
- }
- void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- vrcpps(x, op);
- }
- void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
- vrcp14ps(x, op);
- }
-
- void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- divps(x, op2);
- }
- void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vdivps(x, op1, op2);
- }
-
- void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
- movups(buf, op1);
- divps(buf, op2);
- if (x.getIdx() != buf.getIdx()) {
- movups(x, buf);
- }
- }
-
- void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
- vdivps(x, op1, op2);
- }
-
- void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- addps(x, op2);
- }
- void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vaddps(x, op1, op2);
- }
-
- void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
- const Xbyak::Operand& op) {
- assert(x1.getIdx() == x2.getIdx());
- psignd(x1, op);
- }
- void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
- const Xbyak::Operand& op) {
- vpsignd(x1, x2, op);
- }
-
- void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- subps(x, op2);
- }
- void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vsubps(x, op1, op2);
- }
-
- void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
- movups(buf, op1);
- subps(buf, op2);
- if (x.getIdx() != buf.getIdx()) {
- movups(x, buf);
- }
- }
-
- void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
- vsubps(x, op1, op2);
- }
-
- void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- mulps(x, op2);
- }
- void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vmulps(x, op1, op2);
- }
-
- void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- mulps(x1, x2);
- addps(x1, op);
- }
- void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- vfmadd213ps(x1, x2, op);
- }
-
- void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- mulps(x2, op);
- addps(x1, x2);
- }
- void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- vfmadd231ps(x1, x2, op);
- }
-
- void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- mulps(x2, op);
- subps(x1, x2);
- }
-
- void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- vfnmadd231ps(x1, x2, op);
- }
-
- void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- sqrtps(x, op);
- }
- void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- vsqrtps(x, op);
- }
-
- void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- assert(x1.getIdx() == x2.getIdx());
- paddd(x2, op);
- }
- void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- vpaddd(x1, x2, op);
- }
-
- void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op = Xbyak::Operand()) {
- assert(x1.getIdx() == x2.getIdx());
- andps(x1, op);
- }
- void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op = Xbyak::Operand()) {
- if (!mayiuse(avx512_common) || x1.getBit() < 512)
- vandps(x1, x2, op);
- else
- vpandd(x1, x2, op);
- }
-
- void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op = Xbyak::Operand()) {
- assert(x1.getIdx() == x2.getIdx());
- orps(x1, op);
- }
- void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op = Xbyak::Operand()) {
- if (!mayiuse(avx512_common) || x1.getBit() < 512)
- vorps(x1, x2, op);
- else
- vpord(x1, x2, op);
- }
-
- void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
- const int imm) {
- assert(x.getIdx() == op.getIdx());
- pslld(x, imm);
- }
- void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
- const int imm) {
- vpslld(x, op, imm);
- }
-
- void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
- const int imm) {
- assert(x.getIdx() == op.getIdx());
- psrld(x, imm);
- }
- void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
- const int imm) {
- vpsrld(x, op, imm);
- }
-
- void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- maxps(x, op2);
- }
- void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vmaxps(x, op1, op2);
- }
-
- void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- assert(x.getIdx() == op1.getIdx());
- minps(x, op2);
- }
- void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
- const Xbyak::Operand &op2 = Xbyak::Operand()) {
- vminps(x, op1, op2);
- }
-
- void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- assert(x1.getIdx() == x2.getIdx());
- cmpps(x1, op, _cmp_nle_us);
- }
-
- void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- vcmpgtps(x1, x2, op);
- }
-
- void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op) {
- assert(x1.getIdx() == x2.getIdx());
- cmpps(x1, op, _cmp_nlt_us);
- }
-
- void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op) {
- vcmpps(x1, x2, op, _cmp_nlt_us);
- }
-
- void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
- ptest(x1, op);
- }
-
- void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
- assert(!(x1.isZMM() || op.isZMM()));
- vtestps(x1, op);
- }
-
- void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
- const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
- assert(x1.getIdx() == x2.getIdx());
- assert(msk.getIdx() == 0);
- blendvps(x1, op);
- }
- void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
- const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
- vblendvps(x1, x2, op, msk);
- }
-
- void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
- const int imm) {
- roundps(x, op, imm);
- }
- void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
- const int imm) {
- vroundps(x, op, imm);
- }
-
- void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- cvtps2dq(x, op);
- }
- void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- vcvtps2dq(x, op);
- }
-
- void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
- cvtdq2ps(x, op);
- }
- void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
- vcvtdq2ps(x, op);
- }
-
- void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
- movmskps(x1.cvt64(), x2);
- }
- void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
- vmovmskps(x1, x2);
- }
-
- void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
- assert(x1.getIdx() == x2.getIdx());
- packssdw(x1, op);
- }
- void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
- vpackssdw(x1, x2, op);
- }
-
- void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
- assert(x1.getIdx() == x2.getIdx());
- packuswb(x1, op);
- }
- void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
- vpackuswb(x1, x2, op);
- }
-
-
- void mul_by_const(const Xbyak::Reg &out,
- const Xbyak::Reg64 &tmp, int value) {
- // Generates a shift + add sequence for multiplying the contents of the
- // out register by a known JIT-time value. Clobbers the tmp register.
- //
- // Pros compared to mul/imul:
- // - does not require using known registers
- // - not microcoded on Intel(R) Xeon Phi(TM) processors
- // Still, there are probably a lot of cases when mul/imul is faster on
- // Intel(R) Core(TM) processors. Not intended for critical path.
-
- // TODO: detect when overflow is imminent (Roma)
- // TODO: detect when using mul/imul is a better option (Roma)
-
- int p = 0; // the current power of 2
- int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
-
- xor_(tmp, tmp);
- while (value) {
- if (value & 1) {
- int shift = p - old_p;
- if (shift) {
- shl(out, shift);
- old_p = p;
- }
- add(tmp, out);
- }
- value >>= 1;
- p++;
- }
- mov(out, tmp);
- }
-
-public:
- jit_generator(
- void *code_ptr = nullptr,
- size_t code_size = 256 * 1024
- ) : Xbyak::CodeGenerator(code_size, code_ptr)
- {
- }
- virtual ~jit_generator() {}
-
- virtual const char *name() const = 0;
- virtual const char *source_file() const = 0;
-
- const Xbyak::uint8 *getCode() {
- const Xbyak::uint8 *code = CodeGenerator::getCode();
- size_t code_size = getSize();
- jit_utils::register_jit_code(code, code_size, name(), source_file());
- return code;
- }
-
- template<typename F> const F getCode() {
- return (const F)getCode();
- }
-};
-
-}
-}
-}
-
-#endif
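// Editor's note (illustrative sketch, not part of the deleted file above):
// mul_by_const() in jit_generator.hpp emits a shift + add sequence instead of
// mul/imul. The same decomposition, modeled as ordinary scalar C++ so the
// register dance is easier to follow (the function name is hypothetical):

#include <cstdint>

// For every set bit of `value`, shift the running copy of `out` up to that bit
// position and accumulate it, yielding out * value without a multiply.
inline uint64_t mul_by_const_model(uint64_t out, unsigned value) {
    uint64_t tmp = 0;
    int p = 0;      // current bit position of `value`
    int old_p = 0;  // bit position `out` has already been shifted to
    while (value) {
        if (value & 1) {
            int shift = p - old_p;
            if (shift) { out <<= shift; old_p = p; }
            tmp += out;  // out == original_out << p at this point
        }
        value >>= 1;
        ++p;
    }
    return tmp;
}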
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp
deleted file mode 100644
index 56d7f592e2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp
+++ /dev/null
@@ -1,481 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_PRIMITIVE_CONF_HPP
-#define JIT_PRIMITIVE_CONF_HPP
-
-#include <stdint.h>
-
-#include "common/primitive_attr.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-/* convolution */
-enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni};
-enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn,
- loop_ngcw, loop_nhwcg, loop_nwcg};
-enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr,
- loop_brl};
-enum conv_kernel_kind_t {embd_bcast, expl_bcast};
-
-enum {
- FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1,
- FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3,
- FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5,
- FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7,
- FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9,
- FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips
- loading weights-data from memory; this
- needs to happen on the first Group/16
- iteration. */
- FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skips
- loading bias data from memory */
- FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution
- pass */
-};
-
-struct jit_conv_conf_t {
- prop_kind_t prop_kind;
- conv_version_t ver;
- conv_loop_order_t loop_order;
-
- int simd_w;
- int ndims;
- int mb;
- int ngroups, ic, oc, oc_without_padding, ic_without_padding;
- int id, ih, iw, od, oh, ow;
- int f_pad, l_pad, t_pad;
- int back_pad, r_pad, b_pad;
- int kd, kh, kw;
- int stride_d, stride_h, stride_w;
- int dilate_d, dilate_h, dilate_w;
- format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
- bool with_bias;
- bool with_sum;
- bool with_eltwise;
-
- post_ops_t::entry_t::eltwise_t eltwise;
-
- int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
-
- int idp, ihp, iwp, ohp, owp;
- int nb_ic, ic_block;
- int nb_oc, oc_block;
- int nb_ow, ow_block;
- int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking taking
- into account vector registers distribution */
- int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
- within threads */
- int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
- int nb_ic_L2;
- int h_blocking;
- int nb_oc_L2;
- int ur_h, ur_w;
- int ur_w_tail;
- bool is_1stconv;
- int nonblk_group_off;
- /* fma avx512_core */
- conv_kernel_kind_t kernel_kind;
- /* 4fma */
- int tr_iw;
- int tr_src_num_guard_elems;
- /* 1st conv: 4fma */
- int tr_ld;
- int kh_step;
- /* 4vnni */
- int typesize_in;
- int typesize_out;
- int typesize_bia;
- int typesize_acc;
- /* avx512_u8s8u8 */
- int ic_nb1, ic_nb2;
- int oc_nb1;
- int ur_ow_max, ur_ow, ur_ow_tail;
- int ur_ow_nsteps;
- data_type_t bia_dt;
- data_type_t dst_dt;
- /* avx512: max possible value is nregs(32) - aux_regs(4) */
- int src_offsets[28];
- int src_count;
- bool expl_bcast;
- bool large_spatial;
- int is_oc_scale;
- int max_regs_ur; // maximum accumulation registers
- // dw conv
- int nb_ch, ch_block, nb_ch_blocking;
- bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
- int aligned_threads;
- // large spatial
- int oh_blk_size;
- // s8s8 convolution
- bool signed_input;
- float wei_adj_scale;
-};
-
-struct jit_conv_conf_2x3_wino_t {
- conv_version_t ver;
-
- int m;
- int r;
- int alpha;
- int tile_h, tile_w;
-
- int mb;
- int ngroups, ic, oc, oc_without_padding;
- int ih, iw, oh, ow;
- int l_pad, t_pad;
- int r_pad, b_pad;
- int kh, kw;
- int stride_h, stride_w;
- int dilate_h, dilate_w;
-
- int nb_ic, ic_block;
- int nb_oc, oc_block;
-
- int w_block_size, h_block_size;
-
- data_type_t bia_dt;
- data_type_t dst_dt;
-
- int is_oc_scale;
- int typesize_in;
- int typesize_out;
- int typesize_bia;
- int typesize_acc;
-
- format_tag_t src_tag, dst_tag; // temporary workaround
- bool with_bias;
- bool small_mb;
-
- int xb, yb;
- int inp_stride;
- int out_stride;
- int wei_stride;
- int bia_stride;
-
- int M, N, K;
- int m_block, n_block, k_block;
- int n2_block, n_chunks;
- int k2_block, k_chunks;
-
- int mb_block, nb_mb;
-
- size_t size_wino_src, size_wino_wei, size_wino_dst;
-
- int nthr;
-};
-
-/*
- Winograd sched policy:
-
- Computation Unit:
- W: weights transform
- S: src transform
- D: dst transform
- G: gemm
-
- Thread grouping by:
- i: nb_ic
- o: nb_oc
- t: tile_block
- e: element in tile
-
- Note: 'i' and 'o' are omitted if
- i. not combined with t or
- ii. with discrete transforms
-
- Current policies supported:
-*/
-enum winograd_sched_t {
- WSCHED_INVALID = 0,
-
- /* Forward & backward-data */
- /* W_S_G_D implements discrete transforms */
- WSCHED_DATA_W_S_G_D,
- /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2*/
- WSCHED_DATA_W_SGD,
-
- /* Backward-weights */
- WSCHED_WEI_S_D_G_W,
- WSCHED_WEI_SDGtWo,
- WSCHED_WEI_S_D_Giot_W,
- WSCHED_WEI_SDGt_W,
-};
-
-struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
- int itiles;
- int jtiles;
- int ntiles;
- int ic_simd_block=16;
- int tile_4fma_padding;
- int tile_4fma;
- int oc_simd_block=16;
- int oc_reg_block;
- int ic_reg_block;
- int tile_block;
- int tile_block_ur;
- int nb_tile_block_ur;
-
- bool double_buffering;
- bool with_relu_postsum;
- int zmm_start;
- int nb_reg;
-
- int dimK;
- int dimK_4fma;
- int dimK_reg_block;
- int dimK_block;
- int dimK_nb_block;
-
- int dimM;
- int dimM_reg_block;
- int dimM_simd_block;
- int dimM_block;
- int dimM_nb_block;
-
- int dimN;
- int dimN_reg_block;
- int dimN_bcast_ur;
- int dimN_block;
- int dimN_nb_block;
-
- winograd_sched_t sched_policy;
-};
-
-struct jit_conv_call_s {
- const void *src; /* hack, non-const for backward_data */
- const void *dst; /* hack, non-const for forward */
- const void *filt; /* hack, non-const for backward_weights */
- const void *bias; /* hack, non-const for backward_bias */
- const void *src_prf;
- const void *dst_prf;
- const void *filt_prf;
- const void *bias_prf;
- const void *scales;
- const void *acc_s32;
- const void *compensation;
- size_t kd_offset;
- size_t kd_offset_prf;
- size_t d_index;
- size_t d_index_prf;
- size_t d_worksize;
- size_t d_worksize_prf;
- size_t kd_padding;
- size_t kd_padding_prf;
- size_t kh_padding;
- size_t kh_padding_prf;
- size_t owb;
- size_t owb_prf;
- size_t kw_padding;
- size_t channel;
- size_t channel_prf;
- size_t oc_blocks;
- size_t ur_w;
- size_t ur_str_w;
- size_t ch_blocks;
- size_t t_overflow;
- size_t b_overflow;
- int flags;
-};
-
-struct jit_deconv_call_s {
- const void *src; /* hack, non-const for backward_data */
- const void *dst; /* hack, non-const for forward */
- const void *filt; /* hack, non-const for backward_weights */
- const void *bias; /* hack, non-const for backward_bias */
- const void *scales;
- const void *compensation;
- size_t t_overflow;
- size_t b_overflow;
- size_t kh_padding;
- size_t oc_blocks;
-};
-
-struct jit_dw_conv_call_s {
- const void *input;
- const void *output;
- const void *filter;
- const void *bias;
- size_t kh_count;
- size_t oh_count;
- size_t oh_index;
- size_t filter_pad_off;
- unsigned char
- exec_flags; /* Flags passed by driver execution to inner kernel */
-};
-
-struct jit_wino_transform_call_s {
- size_t tile_block;
- size_t tile_block_ur;
- size_t nb_tile_block_ur;
- size_t tile_count;
- size_t tj;
- size_t ti;
- void *src;
- void *dst;
- void *Mw;
- void *M;
- void *T;
- void *G;
- void *bias;
-};
-
-struct jit_1x1_conv_conf_t {
- prop_kind_t prop_kind;
- conv_version_t ver;
-
- int mb;
- int ngroups, ic, oc, oc_without_padding, ic_without_padding;
- int iw, ih, ow, oh;
- int l_pad, t_pad;
- int kh, kw;
- int stride_h, stride_w;
- format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
- bool with_bias;
- bool with_sum;
- bool with_eltwise;
-
- post_ops_t::entry_t::eltwise_t eltwise;
-
- int is, os;
- int ic_block, oc_block;
-
- int ur, ur_tail;
-
- int reduce_dim, reduce_block, nb_reduce,
- nb_reduce_blocking, nb_reduce_blocking_max;
- int load_dim, load_block, nb_load,
- nb_load_blocking, nb_load_blocking_max, nb_load_chunk;
- int bcast_dim, bcast_block, nb_bcast,
- nb_bcast_blocking, nb_bcast_blocking_max;
-
- int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
- int load_loop_load_step, load_loop_iter_step;
- int bcast_loop_output_step, bcast_loop_output_substep;
- int bcast_loop_bcast_step, bcast_loop_bcast_substep;
- int fma_step;
- int load_grp_count;
- conv_1x1_loop_order_t loop_order;
- bool use_vmovntps;
- /* avx512 core */
- bool expl_bcast;
- /* 4vnni */
- int typesize_in;
- int typesize_out;
- int typesize_bia;
- int typesize_acc;
- /* 4fma */
- bool transpose_src;
- int tr_is;
- int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
- int is_oc_scale;
- data_type_t bia_dt;
- data_type_t dst_dt;
- bool signed_input;
- float wei_adj_scale;
-};
-
-struct jit_gemm_conv_conf_t {
- prop_kind_t prop_kind;
-
- int mb;
- int ngroups, ic, oc;
- int iw, ih, id, ow, oh, od;
- int l_pad, t_pad, f_pad;
- int kh, kw, kd;
- int stride_h, stride_w, stride_d;
- int dilate_h, dilate_w, dilate_d;
- bool with_bias;
-
- int is, os, ks;
- int ic_block, oc_block;
-
- int nthr;
- ptrdiff_t im2col_sz;
- bool need_wei_reduction;
- bool signed_input;
- int oh_block;
- int ow_block;
- bool outer_threading;
-};
-
-struct jit_1x1_conv_call_s {
- const void *bcast_data;
- const void *load_data;
- const void *output_data;
- const void *bias_data; // used in forward and backward_weights only
- const void *acc_s32;
- const void *scales;
- const void *compensation;
-
- size_t load_dim;
- size_t bcast_dim;
- size_t reduce_dim;
-
- size_t output_stride; // used in backward_weights only
-
- size_t first_last_flag;
-};
-
-/* pooling */
-struct jit_pool_conf_t {
- int ndims;
- int mb, c;
- int id, ih, iw, od, oh, ow;
- int stride_d, stride_h, stride_w;
- int kd, kh, kw;
- int f_pad, t_pad, l_pad;
- alg_kind_t alg;
- bool is_training;
- bool pad_w_is_null;
- bool is_backward;
- bool simple_alg;
- data_type_t ind_dt;
-
- int c_block, c_tail, nb_c;
- int ur_c, ur_c_tail;
- int ur_w;
- int ur_w_tail;
- size_t tail[4];
- data_type_t src_dt;
- data_type_t dst_dt;
-};
-
-struct jit_pool_call_s {
- const float *src;
- const float *dst;
- const void *indices;
- const float *src_prf;
- const float *dst_prf;
- const void *indices_prf;
- size_t oh;
- size_t kd_padding;
- size_t kh_padding;
- size_t kh_padding_shift;
- size_t kd_padding_shift;
- size_t kw_padding;
- const float* init_value;
- float ker_area_h;
-};
-
-
-}
-}
-}
-
-#endif
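// Editor's note (illustrative sketch, not part of the deleted file above): the
// FLAG_REDUCE_FIRST / FLAG_REDUCE_LAST bits defined in jit_primitive_conf.hpp
// tell a 1x1 kernel whether the current chunk of the reduction axis is the first
// one (initialize accumulators, optionally from bias) or the last one (apply
// post-ops). A hedged sketch of how a driver loop might set them; the names
// below are hypothetical:

enum : int { REDUCE_FIRST = 1 << 8, REDUCE_LAST = 1 << 9 };

// Returns the first/last flags for the reduction chunk starting at block `rb`,
// when `nb_reduce` blocks are processed `blocking` blocks at a time.
inline int reduce_chunk_flags(int rb, int blocking, int nb_reduce) {
    int flags = 0;
    if (rb == 0) flags |= REDUCE_FIRST;
    if (rb + blocking >= nb_reduce) flags |= REDUCE_LAST;
    return flags;
}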
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp
deleted file mode 100644
index 94d2101d6e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp
+++ /dev/null
@@ -1,677 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "cpu_memory.hpp"
-
-#include "jit_sse42_1x1_conv_kernel_f32.hpp"
-
-#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
-{
- mov(aux1_reg_bcast_data, reg_bcast_data);
- mov(aux_reg_output_data, reg_output_data);
- mov(bcast_loop_iter, reg_bcast_loop_work);
-
- Label bcast_loop;
- Label bcast_loop_tail;
-
- cmp(bcast_loop_iter, jcp.ur);
- jl(bcast_loop_tail, T_NEAR);
-
- L(bcast_loop); {
- assert(jcp.bcast_block % jcp.ur == 0);
- int num_substeps = jcp.bcast_block / jcp.ur;
- assert(num_substeps > 0 && num_substeps < 10);
- for (int i = 0; i < num_substeps; i++) {
- generate_reduce_loop(load_loop_blk, jcp.ur);
- if (i < num_substeps - 1) {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_substep);
- } else {
- add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
- - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
- add(aux_reg_output_data, jcp.bcast_loop_output_step
- - (num_substeps - 1) * jcp.bcast_loop_output_substep);
- }
- }
- sub(bcast_loop_iter, jcp.bcast_block);
- cmp(bcast_loop_iter, jcp.bcast_block);
- jge(bcast_loop, T_NEAR);
- }
-
- L(bcast_loop_tail);
- if (jcp.ur_tail) {
- Label bcast_loop_tail_out;
- cmp(bcast_loop_iter, 0);
- jz(bcast_loop_tail_out, T_NEAR);
- generate_reduce_loop(load_loop_blk, jcp.ur_tail);
- L(bcast_loop_tail_out);
- }
-}
-
-void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
- int load_loop_blk, int ur)
-{
- auto reg_load = [=](int i, int n) {
- return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
- };
-
- auto reg_accum = [=](int i, int j, int n) {
- return Xmm(2*j * load_loop_blk + 2*i + n + 1);
- };
-
- auto bias_ptr = [=](int i, int n) {
- return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
- };
-
- auto bcast_ptr = [=](int u, int j) {
- assert(j < jcp.ur);
- assert(u <= jcp.reduce_loop_unroll);
- size_t offt;
- if (one_of(jcp.prop_kind,
- forward_training, forward_inference, backward_data)) {
- assert(jcp.reduce_loop_unroll == ((jcp.prop_kind == backward_data)
- ? jcp.oc_block : jcp.ic_block));
- auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
- offt = (u == jcp.reduce_loop_unroll)
- ? (height + j) * jcp.reduce_loop_unroll
- : j * jcp.reduce_loop_unroll + u;
- } else
- offt = u * jcp.ic_block + j;
- return ptr[aux_reg_bcast_data + sizeof(float) * offt];
- };
-
- auto load_ptr = [=](int u, int i, int n) {
- size_t offt;
- size_t u0 = u % jcp.reduce_loop_unroll;
- size_t u1 = u / jcp.reduce_loop_unroll;
- switch (jcp.prop_kind) {
- case backward_data:
- offt = (i * jcp.oc_block + u0) * jcp.ic_block;
- break;
- case backward_weights:
- offt = (i * jcp.os + u0) * jcp.oc_block;
- break;
- default:
- offt = (i * jcp.ic + u0) * jcp.oc_block;
- }
- return ptr[aux_reg_load_data
- + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
- };
-
- auto output_ptr = [=](int i, int j, int n) {
- switch (jcp.prop_kind) {
- case backward_data:
- return ptr[aux_reg_output_data +
- (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
- case backward_weights:
- return ptr[aux_reg_output_data
- + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
- + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
- default:
- return ptr[aux_reg_output_data +
- (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
- }
- };
-
- auto init = [=]() {
- Label init_done;
- Label init_zero;
-
- if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
- forward_inference)) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jz(init_zero);
-
- for (int i = 0; i < load_loop_blk; i++)
- for (int j = 0; j < ur; ++j) {
- movups(reg_accum(i, j, 0), bias_ptr(i, 0));
- movups(reg_accum(i, j, 1), bias_ptr(i, 1));
- }
- jmp(init_done);
- }
-
- L(init_zero);
- for (int i = 0; i < load_loop_blk; ++i)
- for (int j = 0; j < ur; ++j) {
- auto r0 = reg_accum(i, j, 0);
- auto r1 = reg_accum(i, j, 1);
- xorps(r0, r0);
- xorps(r1, r1);
- }
-
- L(init_done);
-
- // load weights
- for (int i = 0; i < load_loop_blk; ++i) {
- movups(reg_load(i, 0), load_ptr(0, i, 0));
- movups(reg_load(i, 1), load_ptr(0, i, 1));
- }
-
- movss(reg_bcast, bcast_ptr(0, 0));
- shufps(reg_bcast, reg_bcast, 0);
- }; // init()
-
- auto store = [=]() {
- Label store_noadd;
-
- if (!jcp.with_sum) {
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jnz(store_noadd, T_NEAR);
- }
-
- for (int j = 0; j < ur; ++j)
- for (int i = 0; i < load_loop_blk; ++i) {
- auto r0 = reg_accum(i, j, 0);
- auto r1 = reg_accum(i, j, 1);
- addps(r0, output_ptr(i, j, 0));
- addps(r1, output_ptr(i, j, 1));
- }
-
- L(store_noadd);
-
- if (jcp.with_eltwise) {
- assert(ur * load_loop_blk < 14);
-
- Label store_norelu;
- test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
- jz(store_norelu, T_NEAR);
-
- eltwise_injector_->compute_vector_range(1,
- 2 * ur * load_loop_blk + 1);
-
- L(store_norelu);
- }
-
- for (int j = 0; j < ur; ++j)
- for (int i = 0; i < load_loop_blk; ++i) {
- movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
- movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
- }
- };
-
- auto fma_block = [=](bool last_block) {
- for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
- for (int j = 0; j < ur; ++j) {
- for (int i = 0; i < load_loop_blk; ++i) {
- mulps(reg_load(i, 0), reg_bcast);
- mulps(reg_load(i, 1), reg_bcast);
- addps(reg_accum(i, j, 0), reg_load(i, 0));
- addps(reg_accum(i, j, 1), reg_load(i, 1));
-
- if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
- movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
- movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
- }
- }
- if (j < ur - 1) {
- movss(reg_bcast, bcast_ptr(u, j + 1));
- shufps(reg_bcast, reg_bcast, 0);
- }
- } // for ur
- if (!last_block || u < jcp.reduce_loop_unroll - 1) {
- movss(reg_bcast, bcast_ptr(u + 1, 0));
- shufps(reg_bcast, reg_bcast, 0);
- }
- } // for reduce_loop_unroll
- };
-
- Label reduce_loop;
- Label reduce_loop_tail;
-
- mov(aux_reg_load_data, reg_load_data);
- mov(aux_reg_bcast_data, aux1_reg_bcast_data);
-
- init();
-
- mov(reduce_loop_iter, reg_reduce_loop_work);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jle(reduce_loop_tail, T_NEAR);
-
- L(reduce_loop); {
- fma_block(false);
- add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jg(reduce_loop, T_NEAR);
- }
-
- L(reduce_loop_tail);
- fma_block(true);
-
- store();
-} // reduce_loop()
-
-void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
-{
- if (!jcp.with_bias || jcp.prop_kind != backward_weights)
- return;
-
- Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
- Label diff_bias_load;
-
- auto diff_bias_ptr = [=](int i, int n) {
- return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
- };
-
- auto load_ptr = [=](int u, int i, int n) {
- return ptr[aux_reg_load_data
- + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
- };
-
- auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };
-
- mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
- cmp(reg_diff_bias_data, 0);
- je(diff_bias_loop_out, T_NEAR);
-
- test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
- jz(diff_bias_load, T_NEAR);
-
- for (int i = 0; i < load_loop_blk; ++i) {
- auto r0 = diff_bias_reg(i, 0);
- auto r1 = diff_bias_reg(i, 1);
- xorps(r0, r0);
- xorps(r1, r1);
- }
- jmp(diff_bias_init_out, T_NEAR);
-
- L(diff_bias_load);
- for (int i = 0; i < load_loop_blk; ++i) {
- movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
- movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
- }
-
- L(diff_bias_init_out);
- mov(aux_reg_load_data, reg_load_data);
- mov(reduce_loop_iter, reg_reduce_loop_work);
- L(diff_bias_loop); {
- for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
- for (int i = 0; i < load_loop_blk; ++i) {
- addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
- addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
- }
- assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
- add(aux_reg_load_data, jcp.reduce_loop_load_step);
- sub(reduce_loop_iter, jcp.reduce_loop_unroll);
- jnz(diff_bias_loop, T_NEAR);
- }
-
- for (int i = 0; i < load_loop_blk; i++) {
- movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
- movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
- }
-
- add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
- mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
-
- L(diff_bias_loop_out);
-}
-
-void jit_sse42_1x1_conv_kernel_f32::generate()
-{
- preamble();
-
- mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
- mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
- mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
- if (jcp.with_bias) {
- if (jcp.prop_kind == backward_weights) {
- sub(rsp, stack_space_needed);
- mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
- mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
- } else
- mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
- }
-
- mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
- mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
- mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
- mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
- if (jcp.prop_kind == backward_weights)
- mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
-
- auto generate_load_loop_body = [=] (int load_loop_blk) {
- generate_bcast_loop(load_loop_blk);
- add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
- switch (jcp.prop_kind) {
- case forward_training:
- case forward_inference:
- add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
- add(reg_output_data,
- load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
- break;
- case backward_data:
- add(reg_output_data,
- load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
- break;
- case backward_weights:
- for (int i = 0; i < load_loop_blk; i++)
- add(reg_output_data, reg_output_stride);
- break;
- default:
- assert(!"invalid prop_kind");
- }
- sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
- };
-
- Label load_loop_blk_8;
- Label load_loop_blk_16;
- Label load_loop_blk_24;
- Label load_loop_blk_end;
-
- cmp(reg_load_loop_work, 8);
- jle(load_loop_blk_8, T_NEAR);
-
- cmp(reg_load_loop_work, 32);
- je(load_loop_blk_16, T_NEAR);
-
- cmp(reg_load_loop_work, 16);
- jle(load_loop_blk_16, T_NEAR);
-
- L(load_loop_blk_24); {
- generate_diff_bias_loop(3);
- generate_load_loop_body(3);
- cmp(reg_load_loop_work, 32);
- je(load_loop_blk_16);
- cmp(reg_load_loop_work, 24);
- jge(load_loop_blk_24);
- }
-
- cmp(reg_load_loop_work, 8);
- jle(load_loop_blk_8, T_NEAR);
-
- L(load_loop_blk_16); {
- generate_diff_bias_loop(2);
- generate_load_loop_body(2);
- cmp(reg_load_loop_work, 16);
- jge(load_loop_blk_16);
- }
-
- L(load_loop_blk_8); {
- cmp(reg_load_loop_work, 0);
- je(load_loop_blk_end, T_NEAR);
- generate_diff_bias_loop(1);
- generate_load_loop_body(1);
- }
-
- L(load_loop_blk_end);
-
- if (jcp.with_bias && jcp.prop_kind == backward_weights)
- add(rsp, stack_space_needed);
-
- postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
- jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
-{
- if (!mayiuse(sse42))
- return status::unimplemented;
-
- // TODO (Roma): this code is duplicated from the generic kernel; maybe the
- // configuration struct could do some stuff below
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- const int ndims = src_d.ndims();
-
- jcp.prop_kind = cd.prop_kind;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
- jcp.ow = dst_d.dims()[ndims - 1];
-
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
- jcp.l_pad = cd.padding[0][ndims - 3];
-
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
- jcp.stride_w = cd.strides[ndims - 3];
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- jcp.os = jcp.oh * jcp.ow;
- jcp.is = jcp.ih * jcp.iw;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- const int is_bwd_d = jcp.prop_kind == backward_data;
-
- format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c;
- format_tag_t wei_tag = with_groups
- ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
- gOIhw8o8i)
- : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
- OIhw8o8i);
-
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.ngroups == 1
- && jcp.src_tag == dat_tag
- && jcp.wei_tag == wei_tag
- && jcp.dst_tag == dat_tag;
- if (!args_ok) return status::unimplemented;
-
- const int simd_w = 4;
- jcp.ic_block = jcp.oc_block = simd_w*2;
-
- args_ok = true
- && jcp.oc % jcp.oc_block == 0
- && jcp.ic % jcp.ic_block == 0
- && jcp.t_pad == 0 && jcp.l_pad == 0
- && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
- && jcp.kh == 1 && jcp.kw == 1;
- if (!args_ok) return status::unimplemented;
-
- jcp.ur = 1;
-
- int load_blocking{ 0 };
- int load_blocking_max{ 0 };
- int bcast_blocking{ 0 };
- int bcast_blocking_max{ 0 };
- int reduce_blocking{ 0 };
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- jcp.reduce_dim = jcp.ic;
- jcp.reduce_block = jcp.ic_block;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.is;
- jcp.bcast_block = jcp.ur;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
- jcp.load_loop_iter_step = jcp.oc_block;
-
- load_blocking = 120; // assumes the kernel is jcp.ur x 3
- load_blocking_max = 144;
- bcast_blocking = 128; // affects load balancing across threads
- bcast_blocking_max = 192;
- reduce_blocking = 128; // affects L1$ utilization
- } else if (jcp.prop_kind == backward_data) {
- jcp.reduce_dim = jcp.oc;
- jcp.reduce_block = jcp.oc_block;
-
- jcp.load_dim = jcp.ic;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.os;
- jcp.bcast_block = jcp.ur;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_output_substep = -1; // unused
- jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
- jcp.bcast_loop_bcast_substep = -1; // unused
-
- jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
- jcp.load_loop_iter_step = jcp.ic_block;
-
- load_blocking = 96; // assumes the kernel is jcp.ur x 3
- load_blocking_max = 144;
- bcast_blocking = 128; // affects load balancing across threads
- bcast_blocking_max = 196;
- reduce_blocking = 64; // affects L1$ utilization
- } else if (jcp.prop_kind == backward_weights) {
- jcp.reduce_dim = jcp.os;
- jcp.reduce_block = 1;
-
- jcp.load_dim = jcp.oc;
- jcp.load_block = jcp.oc_block;
-
- jcp.bcast_dim = jcp.ic;
- jcp.bcast_block = jcp.ic_block;
-
- jcp.reduce_loop_unroll = jcp.reduce_block;
- jcp.reduce_loop_bcast_step
- = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
- jcp.reduce_loop_load_step
- = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
-
- jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
- jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
- jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
- jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
-
- jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
- jcp.load_loop_iter_step = jcp.oc_block;
-
- /* --- */
-
- load_blocking = div_up(jcp.load_dim, jcp.load_block);
- while (true) {
- if (load_blocking <= 32) break;
- else if (load_blocking % 2 == 0) load_blocking /= 2;
- else if (load_blocking % 3 == 0) load_blocking /= 3;
- else break;
- }
- load_blocking *= jcp.load_block;
- load_blocking_max = load_blocking;
- assert(jcp.load_dim % load_blocking == 0);
-
- bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
- while (true) {
- if (bcast_blocking <= 9) break;
- else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
- else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
- else break;
- }
- bcast_blocking *= jcp.bcast_block;
- bcast_blocking_max = bcast_blocking;
- assert(jcp.bcast_dim % bcast_blocking == 0);
-
- reduce_blocking = 128; // affects L1$ utilization
- } else
- return status::unimplemented;
-
- assert(load_blocking);
- assert(load_blocking_max);
- assert(bcast_blocking);
- assert(bcast_blocking_max);
- assert(reduce_blocking);
-
- assert(jcp.bcast_block % jcp.ur == 0);
- jcp.ur_tail = jcp.bcast_dim % jcp.ur;
-
- jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
- jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
- jcp.nb_load_blocking = load_blocking / jcp.load_block;
- jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
- jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
-
- jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
- jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
- jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
-
- return status::success;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp
deleted file mode 100644
index b314a5098c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp
+++ /dev/null
@@ -1,104 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP
-#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "cpu_memory.hpp"
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_sse42_1x1_conv_kernel_f32: public jit_generator {
- jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this,
- jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
- }
-
- ~jit_sse42_1x1_conv_kernel_f32() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_1x1_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr);
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32)
-
- jit_1x1_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_1x1_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using xmm_t = const Xbyak::Xmm;
-
- reg64_t reg_bcast_data = rax;
- reg64_t reg_load_data = rsi;
- reg64_t reg_output_data = rbx;
- reg64_t aux_reg_bcast_data = rdx;
- reg64_t aux1_reg_bcast_data = abi_not_param1;
- reg64_t aux_reg_load_data = abi_param1;
- reg64_t aux_reg_output_data = rbp;
- reg64_t reg_load_loop_work = r9;
- reg64_t reg_bcast_loop_work = r10;
- reg64_t reg_reduce_loop_work = r11;
- reg64_t load_loop_iter = r13;
- reg64_t imm_addr64 = load_loop_iter;
- reg64_t bcast_loop_iter = r14;
- reg64_t reduce_loop_iter = r15;
- reg64_t reg_reduce_pos_flag = r8;
- reg64_t reg_output_stride = r12;
- reg64_t reg_bias_data = r12;
- reg64_t reg_diff_bias_data = bcast_loop_iter;
-
- int reg_diff_bias_data_stack_offt = 0;
- int stack_space_needed = 8;
-
- xmm_t reg_bcast = xmm_t(15);
-
- jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_;
-
- void generate_bcast_loop(int load_loop_blk);
- void generate_reduce_loop(int load_loop_blk, int ur);
- void generate_diff_bias_loop(int load_loop_blk);
-
- void generate();
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp
deleted file mode 100644
index 30c137641e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp
+++ /dev/null
@@ -1,134 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "jit_sse42_1x1_convolution.hpp"
-#include "utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-#define data_blk_off(f, n, c, h, w) \
- ((ndims == 3) \
- ? (f).blk_off(n, c, w) \
- : (f).blk_off(n, c, h, w))
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-void jit_sse42_1x1_convolution_fwd_t::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
- const int ndims = src_d.ndims();
-
- const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
-
- parallel(0, [&](const int ithr, const int nthr) {
- // TODO (Roma): remove this restriction
- assert(jcp.stride_w == 1 && jcp.stride_h == 1);
-
- auto par_conv = jit_1x1_conv_call_s();
-
- const int nb_oc = jcp.nb_load;
- const int nb_ic = jcp.nb_reduce;
- const int nb_ic_blocking = jcp.nb_reduce_blocking;
- const int os_block = jcp.bcast_block;
-
- int start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
-
- int iwork = start;
- while (iwork < end) {
- int n{0}, g{0}, osb{0};
- nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
- jcp.nb_bcast);
-
- const int bcast_step_rem = jcp.nb_bcast - osb;
- int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max
- ? bcast_step_rem : jcp.nb_bcast_blocking;
- bcast_step = nstl::min<int>(bcast_step, end - iwork);
-
- const int os = osb * os_block;
- const int ow = os % jcp.ow;
- const int oh = os / jcp.ow;
- const int iw = nstl::max<int>(ow * jcp.stride_w - jcp.l_pad, 0);
- const int ih = nstl::max<int>(oh * jcp.stride_h - jcp.t_pad, 0);
-
- par_conv.bcast_dim = this_block_size(os, jcp.os,
- bcast_step * os_block);
-
- int ocb = 0;
- while (ocb < jcp.nb_load) {
- const int load_step_rem = jcp.nb_load - ocb;
- const int load_step = load_step_rem < jcp.nb_load_blocking_max
- ? load_step_rem : jcp.nb_load_blocking;
-
- const size_t _ocb = g * nb_oc + ocb;
- par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
- load_step * jcp.oc_block);
-
- const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
- par_conv.output_data = &dst[dst_off];
-
- par_conv.bias_data = &bias[_ocb * jcp.oc_block];
-
- for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
- par_conv.first_last_flag = 0
- | (icb == 0) * FLAG_REDUCE_FIRST
- | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST;
-
- par_conv.reduce_dim = this_block_size(icb * jcp.ic_block,
- jcp.ic, nb_ic_blocking * jcp.ic_block);
-
- const size_t _icb = g * nb_ic + icb;
- const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw);
- par_conv.bcast_data = &src[src_off];
-
- par_conv.load_data = &weights[pd()->with_groups()
- ? weights_d.blk_off(g, ocb, icb)
- : weights_d.blk_off(ocb, icb)];
-
- kernel_->jit_ker(&par_conv);
- }
-
- ocb += load_step;
- }
-
- iwork += bcast_step;
- }
- });
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp
deleted file mode 100644
index b32b1e4784..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp
+++ /dev/null
@@ -1,96 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP
-#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-#include "jit_sse42_1x1_conv_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""),
- jit_sse42_1x1_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(),
- *src_md(), *weights_md(), *dst_md(), *attr());
- }
-
- jit_1x1_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
- : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {
- kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
- }
- ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_sse42_1x1_conv_kernel_f32 *kernel_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp
deleted file mode 100644
index 17cabc1186..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp
+++ /dev/null
@@ -1,497 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "cpu_memory.hpp"
-
-#include "jit_sse42_conv_kernel_f32.hpp"
-
-#define GET_OFF(field) offsetof(jit_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
- int pad_l, int pad_r, int oc_blocks)
-{
- int iw = jcp.iw;
- int ih = jcp.ih;
- int kw = jcp.kw;
- int kh = jcp.kh;
- int nb_ic = jcp.nb_ic;
- int stride_w = jcp.stride_w;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
-
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
- int jj_end = ur_w
- - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
- for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- int inp_off;
- if (one_of(jcp.src_tag, ncw, nchw))
- inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
- else
- inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
-
- movss(Xmm(oc_blocks * ur_w + jj + 1),
- ptr[aux_reg_input + sizeof(float) * inp_off]);
- shufps(Xmm(oc_blocks * ur_w + jj + 1),
- Xmm(oc_blocks * ur_w + jj + 1), 0x0);
- }
-
- for (int ii = 0; ii < oc_blocks; ii++) {
- int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
- + ki * ic_blk * oc_blk + ifm2 * oc_blk;
-
- for (int jj = jj_start; jj < jj_end; jj++)
- {
- movups(xmm0,
- ptr[aux_reg_kernel + sizeof(float) * ker_off]);
- mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
- addps(Xmm(ur_w * ii + jj + 1), xmm0);
- }
- }
- }
- }
-}
-
-void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
- int pad_l, int pad_r, int oc_blocks)
-{
- Label kw_loop;
-
- int iw = jcp.iw;
- int ih = jcp.ih;
- int kw = jcp.kw;
- int kh = jcp.kh;
- int nb_ic = jcp.nb_ic;
- int stride_w = jcp.stride_w;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
-
- xor_(ki_iter, ki_iter);
- L(kw_loop);
- {
- int jj_start = 0;
- int jj_end = ur_w;
- for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
- for (int jj = jj_start; jj < jj_end; jj++) {
- int inp_off;
- if (one_of(jcp.src_tag, ncw, nchw))
- inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
- else
- inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
-
- movss(Xmm(oc_blocks * ur_w + jj + 1),
- ptr[aux_reg_input + sizeof(float) * inp_off]);
- shufps(Xmm(oc_blocks * ur_w + jj + 1),
- Xmm(oc_blocks * ur_w + jj + 1), 0x0);
- }
- for (int ii = 0; ii < oc_blocks; ii++) {
- int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
- + ifm2 * oc_blk;
- for (int jj = jj_start; jj < jj_end; jj++) {
- movups(xmm0,
- ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
- mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
- addps(Xmm(ur_w * ii + jj + 1), xmm0);
- }
- }
- }
- add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
- add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ?
- dilate_w : ic_blk * dilate_w));
-
- inc(ki_iter);
- cmp(ki_iter, kw);
- jl(kw_loop, T_NEAR);
- }
-}
-
-void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
- int pad_l, int pad_r, int oc_blocks)
-{
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ow = jcp.ow;
- int oh = jcp.oh;
- int dilate_h = jcp.dilate_h + 1;
- int dilate_w = jcp.dilate_w + 1;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw)
- ? dilate_h : ic_blk * dilate_h;
- const int inp_off = one_of(jcp.src_tag, ncw, nchw)
- ? dilate_w : ic_blk * dilate_w;
-
- xor_(simd_iter, simd_iter);
-
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
-
- Label init_simd_iter_loop;
- Label init_done;
- Label init_first;
-
- L(init_simd_iter_loop);
-
- if (!jcp.with_sum) {
- test(reg_ci_flag, FLAG_IC_FIRST);
- jne(init_first, T_NEAR);
- }
-
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
- + sizeof(float) * (ii * oh * ow + jj) * oc_blk]);
-
- if (jcp.with_sum && jcp.with_bias) {
- test(reg_ci_flag, FLAG_IC_FIRST);
- je(init_done, T_NEAR);
-
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- addps(Xmm(ur_w * ii + jj + 1),
- xword[reg_bias + sizeof(float) * ii * oc_blk]);
- }
-
- jmp(init_done);
-
- L(init_first);
- if (this->jcp.with_bias) {
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- movups(Xmm(ur_w * ii + jj + 1),
- xword[reg_bias + sizeof(float) * ii * oc_blk]);
- } else {
- for (int ii = 0; ii < oc_blocks; ii++)
- for (int jj = 0; jj < ur_w; jj++)
- pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
- }
-
- L(init_done);
-
- Label skip_kh_loop;
- mov(kj, reg_kh);
- if ((jcp.dilate_h >= jcp.ih)
- || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
- cmp(kj, 0);
- je(skip_kh_loop, T_NEAR);
- }
- Label kh_loop;
- L(kh_loop);
- {
- if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
- oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
- sub(aux_reg_input, sizeof(float) * kw * inp_off);
- add(aux_reg_input, sizeof(float) * iw * inp_mult);
- } else {
- oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
- add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
- add(aux_reg_input, sizeof(float) * iw * inp_mult);
- }
-
- dec(kj);
- cmp(kj, 0);
- jg(kh_loop, T_NEAR);
- }
-
- L(skip_kh_loop);
-
- if (jcp.with_eltwise) {
- Label regular_store;
- test(reg_ci_flag, FLAG_IC_LAST);
- je(regular_store, T_NEAR);
-
- eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1);
-
- L(regular_store);
- }
-
- for (int ii = 0; ii < oc_blocks; ii++) {
- for (int jj = 0; jj < ur_w; jj++) {
- const size_t o_off = (ii * oh * ow + jj) * oc_blk;
-
- Xmm reg_out = Xmm(ur_w * ii + jj + 1);
- movups(xword[reg_output + sizeof(float) * o_off], reg_out);
- }
- }
-
- mov(aux_reg_kernel, reg_kernel);
- mov(aux_reg_input, reg_input);
- add(aux_reg_kernel, sizeof(float) * 4);
- add(reg_output, sizeof(float) * 4);
- add(reg_bias, sizeof(float) * 4);
-
- inc(simd_iter);
- cmp(simd_iter, 2);
- jl(init_simd_iter_loop, T_NEAR);
-
- sub(reg_output, sizeof(float) * 8);
- sub(reg_bias, sizeof(float) * 8);
-}
-
-inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks)
-{
- int ur_w = jcp.ur_w;
- int ur_w_tail = jcp.ur_w_tail;
- int n_oi = jcp.ow / ur_w;
- int iw = jcp.iw;
- int kw = jcp.kw;
- int ic_blk = jcp.ic_block;
- int oc_blk = jcp.oc_block;
- int dilate_w = jcp.dilate_w + 1;
- int str_w = jcp.stride_w;
- const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 1 : ic_blk;
-
- int l_pad = jcp.l_pad;
- int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
- - (iw + l_pad - 1));
- int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
- - (iw + l_pad - 1);
- if (r_pad1 > 0) n_oi--;
-
- if (l_pad > 0) {
- n_oi--;
- if (n_oi < 0 && r_pad1 > 0)
- width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
- else
- width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
- add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
- }
-
- Label ow_loop;
- xor_(oi_iter, oi_iter);
-
- if (n_oi > 0) {
- L(ow_loop);
-
- width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
- add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
-
- inc(oi_iter);
- cmp(oi_iter, n_oi);
- jl(ow_loop, T_NEAR);
- }
-
- if (r_pad1 > 0 && n_oi >=0) {
- width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
- add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
- add(reg_output, sizeof(float) * ur_w * oc_blk);
- }
-
- if (ur_w_tail != 0)
- width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
-}
-
-void jit_sse42_conv_fwd_kernel_f32::generate()
-{
- this->preamble();
-
- mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- if (jcp.with_bias)
- mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
- mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
- mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
- mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
-
- int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
- Label tail, exit;
-
- cmp(reg_oc_blocks, jcp.nb_oc_blocking);
- jne(nb_oc_tail ? tail : exit, T_NEAR);
-
- solve_common(jcp.nb_oc_blocking);
- jmp(exit, T_NEAR);
-
- if (nb_oc_tail) {
- L(tail);
- cmp(reg_oc_blocks, nb_oc_tail);
- jne(exit, T_NEAR);
- solve_common(nb_oc_tail);
- }
-
- L(exit);
-
- this->postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
-{
- if (!mayiuse(sse42)) return status::unimplemented;
-
- jcp.prop_kind = cd.prop_kind;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- const int ndims = src_d.ndims();
- jcp.ndims = ndims;
-
- jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
- jcp.iw = src_d.dims()[ndims - 1];
- jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
- jcp.ow = dst_d.dims()[ndims - 1];
-
- jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
- jcp.kw = weights_d.dims()[with_groups + ndims - 1];
-
- jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
- jcp.l_pad = cd.padding[0][ndims - 3];
-
- jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
- jcp.stride_w = cd.strides[ndims - 3];
-
- jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0];
- jcp.dilate_w = cd.dilates[ndims - 3];
- jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
- - (jcp.ih + jcp.t_pad - 1);
-
- if (ndims == 3) {
- jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(
- Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
- jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c);
- } else if (ndims == 4) {
- jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
- jcp.wei_tag = weights_d.matches_one_of_tag(
- Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
- jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c);
- }
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- const bool flat = jcp.ic == 3;
- const bool mimo = !flat;
-
- bool args_ok = true
- && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc)
- && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o))
- && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c)
- && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o))
- && one_of(jcp.dst_tag, nCw8c, nChw8c);
- if (!args_ok) return status::unimplemented;
-
- const int simd_w = 8; // 2 SSE vectors processing at once
-
- jcp.ur_h = 1; /* no code-unrolling by h so far */
- jcp.ur_w = 3;
- if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
-
- jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
-
- args_ok = true
- && jcp.oc % simd_w == 0
- && jcp.l_pad <= jcp.ur_w
- && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
- || (jcp.stride_w == 1 && jcp.stride_h == 1))
- && IMPLICATION(mimo, jcp.ic % simd_w == 0);
- if (!args_ok) return status::unimplemented;
-
- int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
-
- // kernel needs 1 temporary YMM register
- const int num_avail_regs = 15;
- if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
- /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
- jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
- nstl::min(jcp.ow, num_avail_regs / 2));
- jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
- jcp.ur_w_tail = jcp.ow % jcp.ur_w;
- /* check again ... */
- r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
- if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
- return status::unimplemented;
- }
- assert(jcp.nb_oc_blocking > 0);
- assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
-
- jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
- jcp.nb_ic = jcp.ic / jcp.ic_block;
-
- jcp.oc_block = simd_w;
- jcp.nb_oc = jcp.oc / jcp.oc_block;
-
- if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
- jcp.nb_ic_blocking = 12;
- jcp.nb_ic_blocking_max = 16;
- } else {
- jcp.nb_ic_blocking = 1;
- jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
- }
-
- return status::success;
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp
deleted file mode 100644
index 33c26ef081..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp
+++ /dev/null
@@ -1,93 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP
-#define JIT_SSE42_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "cpu_memory.hpp"
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_sse42_conv_fwd_kernel_f32: public jit_generator {
- jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
- const primitive_attr_t &attr)
- : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this,
- jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- ~jit_sse42_conv_fwd_kernel_f32() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32)
- jit_conv_conf_t jcp;
- const primitive_attr_t &attr_;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using reg64_t = const Xbyak::Reg64;
- reg64_t reg_input = rax;
- reg64_t aux_reg_input = r8;
- reg64_t reg_kernel = rdx;
- reg64_t aux_reg_kernel = r9;
- reg64_t reg_output = rsi;
- reg64_t reg_bias = rbx;
-
- reg64_t kj = r10;
- reg64_t oi_iter = r11;
- reg64_t ki_iter = r12;
- reg64_t reg_kh = abi_not_param1;
- reg64_t simd_iter = r15;
- reg64_t reg_oc_blocks = r14;
- reg64_t imm_addr64 = reg_oc_blocks;
- Xbyak::Reg32 reg_ci_flag = r13d;
-
- jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_;
-
- inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
- int oc_blocks);
- inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
- inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
- inline void solve_common(int oc_blocks);
-
- void generate();
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp
deleted file mode 100644
index 5f77d692f5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp
+++ /dev/null
@@ -1,136 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "jit_sse42_convolution.hpp"
-#include "mkldnn_thread.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-#define src_blk_off(f, n, c, h, w) \
- (pd()->ndims() == 3) \
- ? (f).blk_off(n, c, w) \
- : (f).blk_off(n, c, h, w)
-
-#define wht_blk_off_(f, g, ...) \
- pd()->with_groups() \
- ? (f).blk_off(g, __VA_ARGS__) \
- : (f).blk_off(__VA_ARGS__)
-#define wht_blk_off(f, g, oc, ic, kh, kw) \
- pd()->ndims() == 3 \
- ? wht_blk_off_(f, g, oc, ic, kw) \
- : wht_blk_off_(f, g, oc, ic, kh, kw)
-
-void jit_sse42_convolution_fwd_t::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const auto &jcp = kernel_->jcp;
-
- int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
- const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh;
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{ 0 }, end{ 0 };
- balance211(work_amount, nthr, ithr, start, end);
-
- int icbb = 0;
- while (icbb < jcp.nb_ic) {
- int icb_step = jcp.nb_ic_blocking;
- int icb_step_rem = jcp.nb_ic - icbb;
- if (icb_step_rem < jcp.nb_ic_blocking_max)
- icb_step = icb_step_rem;
-
- size_t n{0}, g{0}, ocbb{0}, oh{0};
- nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
- oh, jcp.oh);
- for (size_t iwork = start; iwork < end; ++iwork) {
- int ocb = ocbb * jcp.nb_oc_blocking;
- int ocb_num = jcp.nb_oc_blocking;
-
- for (int icb = icbb; icb < icbb + icb_step; ++icb) {
- auto par_conv = jit_conv_call_s();
-
- const int ij = oh * jcp.stride_h;
- const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
- const int i_b_overflow = nstl::max(jcp.ih, ij
- + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
-
- const size_t _oc = g * jcp.nb_oc + ocb;
- const size_t _ic = g * jcp.nb_ic + icb;
-
- const int ih = nstl::max(ij - jcp.t_pad
- + div_up(i_t_overflow,
- (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
- par_conv.src = &src[src_blk_off(src_d, n,
- jcp.ic == 3 ? 0 : _ic, ih, 0)];
-
- par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)];
-
- const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
- par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
- jcp.ic == 3 ? 0 : icb, wh, 0)];
-
- if (icb == 0) {
- if (bias)
- par_conv.bias =
- &bias[bias_d.blk_off(_oc * jcp.oc_block)];
- par_conv.flags |= FLAG_IC_FIRST;
- }
-
- if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
- par_conv.flags |= FLAG_IC_LAST;
- }
-
- par_conv.oc_blocks =
- nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
-
- par_conv.kw_padding = 0;
- const int kh_padding = jcp.kh
- - div_up(i_t_overflow, (jcp.dilate_h + 1))
- - div_up(i_b_overflow, (jcp.dilate_h + 1));
- par_conv.kh_padding = nstl::max(0, kh_padding);
- kernel_->jit_ker(&par_conv);
- }
- nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
- oh, jcp.oh);
- }
- icbb += icb_step;
- }
- });
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
deleted file mode 100644
index d2f0a38c5c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
+++ /dev/null
@@ -1,103 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP
-#define CPU_JIT_SSE42_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_primitive_conf.hpp"
-#include "jit_sse42_conv_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_sse42_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", sse42, ""),
- jit_sse42_convolution_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(),
- *src_md(), *weights_md(), *dst_md(), *attr());
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- const bool flat = IC() == 3;
- auto src_tag = flat
- ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
- : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto dst_tag =
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
- auto wei_tag = with_groups()
- ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
- gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
- : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
- OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
-
- return set_default_formats_common(src_tag, wei_tag, dst_tag);
- }
- };
-
- jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); }
- ~jit_sse42_convolution_fwd_t() { delete kernel_; };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_sse42_conv_fwd_kernel_f32 *kernel_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp
deleted file mode 100644
index 0e734f7265..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp
+++ /dev/null
@@ -1,1192 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-#include "jit_generator.hpp"
-#include "cpu_barrier.hpp"
-
-#include "jit_transpose_src_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-#define GET_OFF(x) offsetof(ctx_t, x)
-
-struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t)
-
- jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) {
- generate();
- ker_ = (decltype(ker_))this->getCode();
- }
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using reg32_t = const Xbyak::Reg32;
- using opmask_t = const Xbyak::Opmask;
-
- enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 };
- int src_stride, tr_src_stride;
- int tail;
- bool enable_prefetch;
-
- opmask_t k3333 = k1;
- opmask_t k5555 = k2;
- opmask_t kAAAA = k3;
- opmask_t kCCCC = k4;
- opmask_t k0F0F = k5;
- opmask_t kF0F0 = k6;
- opmask_t kTail = k7;
-
- reg64_t reg_src = r8;
- reg64_t reg_tr_src = r9;
- reg64_t reg_src_prf = r10;
- reg64_t reg_tr_src_prf = r11;
- reg64_t reg_loop = r12;
- reg64_t reg_tr_src_tmp = r13;
- reg32_t regw_tmp = r14d;
-
- void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
- void generate();
-};
-
-void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad,
- bool nontemporal_stores) {
- assert(nrows >= 0 && nrows <= transpose_size);
- static_assert(transpose_size == 16, "Unsupported transpose size");
- if (!nrows)
- return;
-
- auto pf_src_t0 = [=](int i) {
- if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src,
- (transpose_size + i) * src_stride));
- };
-
- auto pf_tr_src_t0 = [=](int i) {
- int offset = (transpose_size) * typesize + i * tr_src_stride;
- if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset));
- if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src,
- offset + 64));
- };
-
- auto pf_src_t1 = [=](int i) {
- if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf,
- i * src_stride));
- };
-
- auto pf_tr_src_t1 = [=](int i) {
- if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf,
- i * tr_src_stride));
- };
-
- auto src_zmm = [=](int i) {
- assert(i >= 0 && i < 16);
- return Zmm(i);
- };
-
- auto tmp_zmm = [=](int i) {
- assert(i >= 0 && i < 16);
- return Zmm(16 + i);
- };
-
- auto load = [=](int i) {
- vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride));
- };
-
- auto store = [=](Zmm r, int i) {
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- auto padding = [=] (Reg64 reg, int pad) {
- kmovw(kTail, (1 << pad) - 1);
- auto k = kTail;
- auto base = reg;
- base.setOpmaskIdx(k.getIdx(), true);
-
- auto zmm_zero = r;
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- auto addr = EVEX_compress_addr(base, i * tr_src_stride);
- vmovups(addr, zmm_zero);
- };
-
- mov(reg_tr_src_tmp, reg_tr_src);
- if (l_pad > 0)
- add(reg_tr_src_tmp, l_pad * typesize);
-
- if (tail != transpose_size)
- kmovw(kTail, (1 << tail) - 1);
-
- // Xbyak does not allow k0 to be specified explicitly via the '|'
- // operator, so we have to do this via a method call (implicitly
- // EVEX encoding uses k0 to mean 'no mask')
- bool partial_store = nrows < 16;
- auto k = partial_store ? kTail : k0;
- auto base = reg_tr_src_tmp;
- base.setOpmaskIdx(k.getIdx(), true);
-
- auto addr = EVEX_compress_addr(base, i * tr_src_stride);
- if (nontemporal_stores && !partial_store)
- vmovntps(addr, r);
- else
- vmovups(addr, r);
-
- if (r_pad > 0) {
- add(reg_tr_src_tmp, tail * typesize);
- padding(reg_tr_src_tmp, r_pad);
- }
-
- if (l_pad > 0) {
- padding(reg_tr_src, l_pad);
- }
- };
-
- auto transpose16x8 = [=](int base_idx) {
- assert(base_idx == 0 || base_idx == 8);
-
- // swap 1
- for (int i = 0; i < 4; i++) {
- int src_idx0 = base_idx + i * 2;
- int src_idx1 = src_idx0 + 1;
-
- int next_src_idx0 = src_idx0 + 2;
- int next_src_idx1 = src_idx1 + 2;
- bool load_next = base_idx == 0 || i < 3;
-
- if (base_idx == 0 && i == 0) {
- load(src_idx0);
- load(src_idx1);
- }
-
- auto tmp0 = tmp_zmm(src_idx0);
- auto tmp1 = tmp_zmm(src_idx1);
- auto src0 = src_zmm(src_idx0);
- auto src1 = src_zmm(src_idx1);
-
- if (next_src_idx0 < nrows && load_next)
- load(next_src_idx0);
- valignd(tmp0, src0, src0, 0x1);
- pf_src_t1(base_idx + i);
-
- if (next_src_idx1 < nrows && load_next)
- load(next_src_idx1);
- valignd(tmp1, src1, src1, 0xf);
- pf_src_t0(base_idx + i);
-
- vmovaps(src0 | kAAAA, tmp1);
- vmovaps(src1 | k5555, tmp0);
- }
- // swap 2
- for (int i = 0; i < 4; i++) {
- int select_half = (i < 2) ? 0 : 2;
- int src_idx0 = base_idx + i + select_half + 0;
- int src_idx2 = src_idx0 + 2;
-
- auto tmp0 = tmp_zmm(src_idx0);
- auto tmp1 = tmp_zmm(src_idx2);
- auto src0 = src_zmm(src_idx0);
- auto src2 = src_zmm(src_idx2);
-
- valignd(tmp0, src0, src0, 0x2);
- pf_src_t1(base_idx + 4 + i);
- valignd(tmp1, src2, src2, 0xe);
- pf_src_t0(base_idx + 4 + i);
- vmovaps(src2 | k3333, tmp0);
- vmovaps(src0 | kCCCC, tmp1);
- }
-
- // swap 4
- for (int i = 0; i < 4; i++) {
- int src_idx0 = base_idx + i;
- int src_idx4 = src_idx0 + 4;
-
- auto tmp0 = tmp_zmm(src_idx0);
- auto src0 = src_zmm(src_idx0);
- auto src4 = src_zmm(src_idx4);
-
- vmovaps(tmp0, src0);
- vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
- pf_tr_src_t1(base_idx / 2 + i);
- vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
- pf_tr_src_t0(base_idx / 2 + i);
- }
- };
-
- auto fixup16x16 = [=]() {
- // swap 8
- for (int i = 0; i < 8; i++) {
- auto tmp = tmp_zmm(i);
- auto src0 = src_zmm(i);
- auto src8 = src_zmm(8 + i);
- vshuff64x2(tmp, src0, src8, 0x44);
- store(tmp, i);
- if (i % 2 == 0) {
- pf_tr_src_t1(8 + i / 2);
- pf_tr_src_t0(8 + i / 2);
- }
- }
-
- for (int i = 0; i < 8; i++) {
- auto tmp = tmp_zmm(8 + i);
- auto src0 = src_zmm(i);
- auto src8 = src_zmm(8 + i);
- vshuff64x2(tmp, src0, src8, 0xee);
- store(tmp, 8 + i);
- if (i % 2 == 0) {
- pf_tr_src_t1(12 + i / 2);
- pf_tr_src_t0(12 + i / 2);
- }
- }
- };
-
- transpose16x8(0);
- transpose16x8(8);
- fixup16x16();
-}
-
-void jit_trans_iw_ic_t::generate() {
- preamble();
-
- const int ic_block = conf_->ic_block;
- const int iw = conf_->iw;
- const int tr_iw = conf_->tr_iw;
- const int transposes = utils::div_up(iw, transpose_size);
- int loop_iters = nstl::max(0, transposes - 1);
- tail = iw - loop_iters * transpose_size;
-
- src_stride = ic_block * typesize;
- assert(src_stride == 64);
- tr_src_stride = tr_iw * typesize;
-
- bool nontemporal_stores = false;
- enable_prefetch = iw > small_spatial ? 1 : 0;
-
- assert(transpose_size == ic_block);
- const int src_step = ic_block * transpose_size * typesize;
- const int tr_src_step = ic_block * typesize;
-
- const int left_pad = conf_->l_pad;
- const int right_pad = tr_iw - iw - left_pad;
-
- mov(reg_src, ptr [param1 + GET_OFF(src)]);
- mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
- mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
- mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
-
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- kmovw(k3333, 0x3333); // 0011001100110011
- kmovw(k5555, 0x5555); // 0101010101010101
- kmovw(kAAAA, 0xaaaa); // 1010101010101010
- kmovw(kCCCC, 0xcccc); // 1100110011001100
- kmovw(k0F0F, 0x0f0f); // 0000111100001111
- kmovw(kF0F0, 0xf0f0); // 1111000011110000
-
- if (left_pad > 0 && loop_iters > 0) {
- loop_iters--;
- transpose(transpose_size, left_pad, 0, nontemporal_stores);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step + left_pad * typesize);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step + left_pad * typesize);
- }
-
- if (loop_iters) {
- mov(reg_loop, loop_iters);
- Label loop;
- L(loop); {
- transpose(transpose_size, 0, 0, nontemporal_stores);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step);
- sub(reg_loop, 1);
- jnz(loop);
- }
- }
- if (transposes > 1)
- transpose(tail, 0, right_pad, nontemporal_stores);
- else
- transpose(tail, left_pad, right_pad, nontemporal_stores);
-
- postamble();
-}
-
-struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t)
- jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf):
- jit_trans_src_t(conf) {
- generate();
- ker_ = (decltype(ker_))this->getCode();
- }
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using reg32_t = const Xbyak::Reg32;
- using opmask_t = const Xbyak::Opmask;
-
- enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 };
- int src_stride, tr_src_stride;
- int tail;
- bool enable_prefetch;
-
- opmask_t kFFFF = k1;
- opmask_t k5555 = k2;
- opmask_t kAAAA = k3;
- opmask_t kAA = k4;
- opmask_t k55 = k5;
- opmask_t kCC = k6;
- opmask_t k33 = k7;
- opmask_t kTail = k1;
-
- reg64_t reg_src = r8;
- reg64_t reg_tr_src = r9;
- reg64_t reg_src_prf = r10;
- reg64_t reg_tr_src_prf = r11;
- reg64_t reg_loop = r12;
- reg64_t reg_tr_src_tmp = r13;
- reg32_t regw_tmp = r14d;
- reg64_t imm_addr64 = rbx;
-
- Xbyak::Zmm vidx1 = zmm31;
- Xbyak::Zmm vidx2 = zmm30;
- Xbyak::Zmm vidx3 = zmm29;
- Xbyak::Zmm vidx4 = zmm28;
- Xbyak::Zmm vidx5 = zmm27;
- Xbyak::Zmm zmm_tmp = zmm26;
-
-
- void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
- void generate();
-};
-
-void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad,
- bool nontemporal_stores) {
- assert(nrows >= 0 && nrows <= transpose_size);
- static_assert(transpose_size == 16, "Unsupported transpose size");
- if (!nrows)
- return;
-
- auto src_zmm = [=](int i) {
- return Zmm(i);
- };
-
- auto src_ymm = [=](int i) {
- assert(i >= 0 && i < 16);
- return Ymm(i);
- };
-
- auto load_ymm = [=](int i) {
- vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride));
- };
-
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- auto store = [=](Zmm r, int i) {
-
- auto padding = [=] (Reg64 reg, int pad) {
- kmovw(kTail, (1 << pad) - 1);
- auto k = kTail;
- auto base = reg;
- base.setOpmaskIdx(k.getIdx(), true);
-
- auto zmm_zero = zmm_tmp;
- vpxord(zmm_zero, zmm_zero, zmm_zero);
- auto addr = EVEX_compress_addr(base, i * tr_src_stride);
- vmovups(addr, zmm_zero);
- };
-
- int store_tail = (nrows%2) ? nrows+1 : nrows;
-
- int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2;
- mov(reg_tr_src_tmp, reg_tr_src);
- if (l_pad > 0) {
- padding(reg_tr_src, store_pad);
- add(reg_tr_src_tmp, l_pad * typesize);
- }
- if (r_pad > 0) {
- store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2;
- int addr_shift = (r_pad%2) ? 1 : 0;
- add(reg_tr_src_tmp, (nrows - addr_shift) * typesize);
- padding(reg_tr_src_tmp, store_pad);
- }
-
- mov(reg_tr_src_tmp, reg_tr_src);
- add(reg_tr_src_tmp, l_pad * typesize);
-
- kmovw(kTail, (1 << store_tail/2) - 1);
- auto k = kTail;
- auto base = reg_tr_src_tmp;
- base.setOpmaskIdx(k.getIdx(), true);
-
- auto addr = EVEX_compress_addr(base, i * tr_src_stride);
- vmovups(addr, r);
-
- };
-
- kmovw(kFFFF, 0xffff);
- //all loads
- for (int i=0; i<16; i++){
- vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
- }
-
- for (int i = 0; i < nrows/2; i++) {
- auto src0 = src_ymm(2*i);
- auto src1 = src_ymm(2*i+1);
- auto zmm_src0 = src_zmm(2*i);
- load_ymm(2*i);
-
- vpunpcklwd(src1, src0,
- EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
- vpunpckhwd(src0, src0,
- EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
- vinserti64x4(zmm_src0, zmm_src0, src1, 1);
- vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
- }
-
- // for odd numbers we need to mix row with zeroes
- if (nrows%2) {
- int i = nrows-1;
- auto src0 = src_ymm(i);
- auto src1 = src_ymm(i+1); //zero
-
- auto zmm_src0 = src_zmm(i);
- vpxor(src1, src1, src1);
-
- load_ymm(i);
- vpunpckhwd(src0, src0, src1);
- vinserti64x4(zmm_tmp, zmm_tmp, src0, 0);
- vpxor(src0, src0, src0);
- load_ymm(i);
- vpunpcklwd(src1, src0, src1);
- vinserti64x4(zmm_tmp, zmm_tmp, src1, 1);
- vpxord(zmm_src0, zmm_src0, zmm_src0);
- vmovups(zmm_src0, zmm_tmp);
- vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
- }
-
- // swap 1
- for (int i=0; i<4; i++) {
- auto zmm0 = src_zmm(4*i);
- auto zmm1 = src_zmm(4*i+2);
- auto tmp0 = src_zmm(4*i+1);
- auto tmp1 = src_zmm(4*i+3);
-
- vmovups(tmp0, zmm0);
- vmovups(tmp1, zmm1);
-
- vpermps(tmp0 | kAAAA, vidx3, zmm1);
- vpermps(tmp1 | k5555, vidx3, zmm0);
- }
- // swap 2
- int base_idx;
- base_idx=0;
- for (int i=0; i<2; i++) {
- auto zmm0 = src_zmm(base_idx+2*i+1);
- auto zmm1 = src_zmm(base_idx+2*i+5);
-
- auto tmp0 = src_zmm(base_idx+2*i);
- auto tmp1 = src_zmm(base_idx+2*i+4);
-
- vmovupd(tmp0, zmm0);
- vmovupd(tmp1, zmm1);
-
- vpermpd(tmp0 | kAA, vidx2, zmm1);
- vpermpd(tmp1 | k55, vidx2, zmm0);
- }
- base_idx=8;
- for (int i=0; i<2; i++) {
- auto zmm0 = src_zmm(base_idx+2*i+1);
- auto zmm1 = src_zmm(base_idx+2*i+5);
-
- auto tmp0 = src_zmm(base_idx+2*i);
- auto tmp1 = src_zmm(base_idx+2*i+4);
-
- vmovupd(tmp0, zmm0);
- vmovupd(tmp1, zmm1);
-
- vpermpd(tmp0 | kAA, vidx2, zmm1);
- vpermpd(tmp1 | k55, vidx2, zmm0);
- }
-
- // swap 3
- for (int i=0; i<4; i++) {
- auto zmm0 = src_zmm(2*i);
- auto zmm1 = src_zmm(2*i+8);
-
- auto tmp0 = src_zmm(2*i+1);
- auto tmp1 = src_zmm(2*i+9);
-
- vmovupd(tmp0, zmm0);
- vmovupd(tmp1, zmm1);
-
- vpermpd(tmp0 | kCC, vidx1, zmm1);
- vpermpd(tmp1 | k33, vidx1, zmm0);
- }
-
- // all stores
- for (int i=0; i<8; i++)
- vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1);
-
- store(src_zmm(1), 0);
- store(src_zmm(0), 1);
- store(src_zmm(3), 2);
- store(src_zmm(2), 3);
- store(src_zmm(9), 4);
- store(src_zmm(8), 5);
- store(src_zmm(11), 6);
- store(src_zmm(10), 7);
- store(src_zmm(5), 8);
- store(src_zmm(4), 9);
- store(src_zmm(7), 10);
- store(src_zmm(6), 11);
- store(src_zmm(13), 12);
- store(src_zmm(12), 13);
- store(src_zmm(15), 14);
- store(src_zmm(14), 15);
-
-}
-
-void jit_trans_iw_ic_int16_t::generate() {
- preamble();
-
- alignas(64) static constexpr const int64_t idx1[8]
- = { 2, 3, 0, 1, 6, 7, 4, 5 };
- alignas(64) static constexpr const int64_t idx2[8]
- = { 1, 0, 3, 2, 5, 4, 7, 6 };
- alignas(64) static constexpr const int32_t idx3[16]
- = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 };
- alignas(64) static constexpr const int32_t idx4[16]
- = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 };
- alignas(64) static constexpr const int32_t idx5[16]
- = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 };
-
- const int ic_block = conf_->ic_block;
- const int iw = conf_->iw;
- const int tr_iw = conf_->tr_iw;
- const int transposes = utils::div_up(iw, transpose_size);
- int loop_iters = nstl::max(0, transposes - 1);
- tail = iw - loop_iters * transpose_size;
-
- src_stride = ic_block * typesize;
- tr_src_stride = tr_iw * typesize;
-
- bool nontemporal_stores = false;
- enable_prefetch = iw > small_spatial ? 1 : 0;
-
- assert(transpose_size == ic_block);
- const int src_step = ic_block * transpose_size * typesize;
- const int tr_src_step = ic_block * typesize;
-
- const int left_pad = conf_->l_pad;
- const int right_pad = tr_iw - iw - left_pad;
-
- mov(reg_src, ptr [param1 + GET_OFF(src)]);
- mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
- mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
- mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
-
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- kmovw(kFFFF, 0xffff);
- kmovw(k5555, 0x5555);
- kmovw(kAAAA, 0xaaaa);
- kmovw(kAA, 0xaa);
- kmovw(k55, 0x55);
- kmovw(kCC, 0xcc);
- kmovw(k33, 0x33);
-
- auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
- mov(imm_addr64, reinterpret_cast<size_t>(addr));
- jit_generator::vmovdqa64(z, ptr[imm_addr64]);
- };
-
- auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
- mov(imm_addr64, reinterpret_cast<size_t>(addr));
- jit_generator::vmovdqa32(z, ptr[imm_addr64]);
- };
-
- vmovdqa64(vidx1, idx1);
- vmovdqa64(vidx2, idx2);
- vmovdqa32(vidx3, idx3);
- vmovdqa32(vidx4, idx4);
- vmovdqa32(vidx5, idx5);
-
- if (left_pad > 0 && loop_iters > 0) {
- loop_iters--;
- transpose(transpose_size, left_pad, 0, nontemporal_stores);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step + left_pad * typesize);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step + left_pad * typesize);
- }
-
- if (loop_iters) {
- mov(reg_loop, loop_iters);
- Label loop;
- L(loop); {
- transpose(transpose_size, 0, 0, nontemporal_stores);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step);
- sub(reg_loop, 1);
- jnz(loop);
- }
- }
- if (transposes > 1)
- transpose(tail, 0, right_pad, nontemporal_stores);
- else
- transpose(tail, left_pad, right_pad, nontemporal_stores);
-
- postamble();
-
-}
-
-struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t)
- jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) {
- generate();
- ker_ = (decltype(ker_))this->getCode();
- }
-
-private:
- using reg64_t = const Xbyak::Reg64;
- using reg32_t = const Xbyak::Reg32;
- using opmask_t = const Xbyak::Opmask;
- using zmm = const Xbyak::Zmm;
-
- enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 };
- int src_stride, tr_src_stride;
- int tail;
- bool enable_prefetch;
-
- opmask_t kFF = k1;
-
- zmm vidx1 = zmm31;
-
- reg64_t reg_src = r8;
- reg64_t reg_tr_src = r9;
- reg64_t reg_src_prf = r10;
- reg64_t reg_tr_src_prf = r11;
- reg64_t reg_loop = r12;
- reg64_t reg_tr_src_tmp = r13;
- reg32_t regw_tmp = r14d;
- reg64_t imm_addr64 = rbx;
-
- void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
- void generate();
-};
-
-void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad,
- bool nontemporal_stores) {
- assert(nrows >= 0 && nrows <= transpose_size);
- static_assert(transpose_size == 16, "Unsupported transpose size");
- if (!nrows)
- return;
-
- auto src_zmm = [=](int i) {
- return Zmm(i);
- };
-
- auto src_ymm = [=](int i) {
- assert(i >= 0 && i < 16);
- return Ymm(i);
- };
-
- auto load_ymm = [=](int i) {
- vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride));
- };
-
-
- auto store = [=](Zmm r, int i) {
- auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride);
- if (nontemporal_stores)
- vmovntps(addr, r);
- else
- vmovups(addr, r);
- };
-
- for (int i = 0; i < nrows/2; i++) {
- auto src0 = src_ymm(2*i);
- auto src1 = src_ymm(2*i+1);
- auto zmm_src0 = src_zmm(2*i);
- load_ymm(2*i);
- vpunpcklwd(src1, src0,
- EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
- vpunpckhwd(src0, src0,
- EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
- vinserti64x4(zmm_src0, zmm_src0, src1, 1);
- vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
- store(zmm_src0, 2*i);
- }
- if (r_pad > 0) {
- auto src0 = src_ymm(nrows-1);
- auto src1 = src_ymm(nrows);
- auto zmm_src0 = src_zmm(30);
- load_ymm(nrows-1);
-
- vpxor(src1, src1, src1);
- vpunpckhwd(src1, src0, src1);
- vinserti64x4(zmm_src0, zmm_src0, src1, 0);
- vpxor(src1, src1, src1);
- vpunpcklwd(src0, src0, src1);
- vinserti64x4(zmm_src0, zmm_src0, src0, 1);
- vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
- store(zmm_src0, nrows-1);
- }
-}
-
-void jit_trans_ow_oc_t::generate() {
- preamble();
-
- alignas(64) static constexpr const int64_t idx1[8]
- = { 4, 5, 0, 1, 6, 7, 2, 3 };
-
- const int oc_block = conf_->oc_block;
- const int ow = conf_->ow;
- const int transposes = utils::div_up(ow, transpose_size);
- int loop_iters = nstl::max(0, transposes - 1);
- tail = ow - loop_iters * transpose_size;
-
- src_stride = oc_block * typesize;
- tr_src_stride = oc_block * typesize;
-
- bool nontemporal_stores = false;
- enable_prefetch = ow > small_spatial ? 1 : 0;
-
- const int src_step = oc_block * transpose_size * typesize;
- const int tr_src_step = oc_block * transpose_size * typesize;
- const int right_pad = ow % 2;
-
- mov(reg_src, ptr [param1 + GET_OFF(src)]);
- mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
- mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
- mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
-
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- kmovw(kFF, 0xFF);
-
- auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
- mov(imm_addr64, reinterpret_cast<size_t>(addr));
- jit_generator::vmovdqa64(z, ptr[imm_addr64]);
- };
-
- vmovdqa64(vidx1, idx1);
- if (loop_iters) {
- mov(reg_loop, loop_iters);
- Label loop;
- L(loop); {
- transpose(transpose_size, 0, 0, nontemporal_stores);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step);
- sub(reg_loop, 1);
- jnz(loop);
- }
- }
- transpose(tail, 0, right_pad, nontemporal_stores);
-
- postamble();
-}
-
-struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t)
-
- jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) {
- generate();
- ker_ = (decltype(ker_))this->getCode();
- }
-
- void generate();
- enum { typesize = (int)sizeof(float) };
-};
-
-/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4]
- * required for 1st 4fma backward by weights convolution */
-void jit_trans_iw_x4_4x_t::generate() {
- using namespace utils;
-
- /* TODO: put into code */
- static int mask[16] = {
- 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, };
-
- const auto &c = *conf_;
- const int simd_w = cpu_isa_traits<avx512_common>::vlen / typesize;
- const int niters = c.tr_ld / simd_w;
-
- assert(niters <= 4); /* [bwd_w:tr_src:r1] */
-
- Reg64 reg_ptr_src = r8;
- Reg64 reg_ptr_tr_src = r9;
-
- Reg64 reg_ih = rax;
- Reg64 reg_ih_end = rbx;
-
- Reg64 reg_nthr_oc_b = rsi;
- Reg64 reg_ptr_tr_src_bctx = abi_not_param1;
-
- Reg64 reg_tmp = rdx;
-
- Zmm vmsk = Zmm(31);
- Opmask kmsk = k7;
-
- auto emit_tr_sync = [&]() {
- simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b);
- };
-
- auto emit_tr_iw = [&]() {
- auto vreg = [](int iter, int i) {
- assert(4 * iter + i < 24);
- return Zmm(4 * iter + i);
- };
- auto vtmp = [](int i) { return Zmm(24 + i); };
-
- auto emit_load = [&](int iter) {
- for (int i = 0; i < 4; ++i) {
- auto v = vreg(iter, i);
- const int off = (iter * 4 + i) * simd_w;
-
- if (off + simd_w <= c.iw)
- vmovups(v, ptr[reg_ptr_src + off * typesize]);
- else if (off < c.iw)
- vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]);
- else
- vpxord(v, v, v);
- }
- };
-
- auto emit_tr = [&](int iter) {
- for (int i = 0; i < 4; ++i)
- vpermps(vreg(iter, i), vmsk, vreg(iter, i));
-
- vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88);
- vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd);
- vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88);
- vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd);
-
- vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88);
- vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd);
- vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88);
- vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd);
- };
-
- auto emit_store = [&]() {
- for (int i = 0; i < 4; ++i) {
- for (int iter = 0; iter < niters; ++iter) {
- const size_t off = i * c.tr_ld + iter * simd_w;
- vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i));
- }
- }
- };
-
- for (int iter = 0; iter < niters; ++iter)
- emit_load(iter);
-
- for (int iter = 0; iter < niters; ++iter)
- emit_tr(iter);
-
- emit_store();
- };
-
- preamble();
-
- mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]);
- mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]);
-
- mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]);
- mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]);
- mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]);
- mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]);
-
- emit_tr_sync();
-
- Label l_ih_loop, l_tr_done;
- cmp(reg_ih, reg_ih_end);
- je(l_tr_done, T_NEAR);
-
- mov(reg_tmp, (size_t)&mask[0]);
- vmovups(vmsk, ptr[reg_tmp]);
-
- if (c.iw % simd_w) {
- const char load_mask = (1 << (c.iw % simd_w)) - 1;
- mov(reg_tmp, load_mask);
- kmovw(kmsk, reg_tmp.cvt32());
- }
-
- /* src += ih_start * c.iw; */
- imul(reg_tmp, reg_ih, c.iw * typesize);
- add(reg_ptr_src, reg_tmp);
- /* tr_src += ih_start * c.stride_w * c.tr_ld; */
- imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize);
- add(reg_ptr_tr_src, reg_tmp);
-
- L(l_ih_loop); {
- emit_tr_iw();
-
- add(reg_ptr_src, c.iw * typesize);
- add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize);
-
- inc(reg_ih);
- cmp(reg_ih, reg_ih_end);
- jl(l_ih_loop, T_NEAR);
- }
-
- L(l_tr_done);
-
- emit_tr_sync();
-
- postamble();
-}
-
-/*
-// -------------------------------------------------
-// jit_transpose4x16_src
-// -------------------------------------------------
-*/
-
-void jit_transpose4x16_src::transpose(int nrows)
-{
- assert(nrows >= 0 && nrows <= transpose_size);
- static_assert(transpose_size == 4, "Unsupported transpose size");
- if (!nrows)
- return;
-
- auto pf_src_t0 = [=](int i) {
- if (tparams->src_pf0_distance)
- prefetcht0(EVEX_compress_addr(
- reg_src, (tparams->src_pf0_distance + i) * src_stride));
- };
-
- auto pf_tr_src_t0 = [=](int i) {
- if (tparams->tr_src_pf0_distance)
- prefetcht0(EVEX_compress_addr(reg_tr_src,
- (tparams->tr_src_pf0_distance + i) * src_stride));
- };
-
- auto pf_src_t1 = [=](int i) {
- if (tparams->src_pf1)
- prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride));
- };
-
- auto pf_tr_src_t1 = [=](int i) {
- if (tparams->tr_src_pf1)
- prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride));
- };
-
- auto src_zmm = [=](int i) {
- assert(i >= 0 && i < 4);
- return Zmm(i);
- };
-
- auto tmp_zmm = [=](int i) {
- assert(i >= 0 && i < 4);
- return Zmm(4 + i);
- };
-
- auto load = [=](int i) {
- vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride));
- };
-
- auto store = [=](Zmm r, int i) {
- vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r);
- };
-
- auto tmp0 = tmp_zmm(0);
- auto tmp1 = tmp_zmm(1);
- auto tmp2 = tmp_zmm(2);
- auto tmp3 = tmp_zmm(3);
-
- auto src0 = src_zmm(0);
- auto src1 = src_zmm(1);
- auto src2 = src_zmm(2);
- auto src3 = src_zmm(3);
- for (int i = 0; i < nrows; i++) {
- load(i);
- }
-
- for (size_t i = nrows; i < 4; i++) {
- vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
- }
-
- vmovupd(tmp0, src0);
- vmovupd(tmp1, src1);
- pf_src_t0(0);
- vpermpd(tmp0 | kF0, vidx01, src2);
- vpermpd(tmp1 | kF0, vidx01, src3);
-
- valignd(src0, src0, src0, 8);
- valignd(src1, src1, src1, 8);
- pf_src_t0(1);
- vmovupd(tmp2, src0);
- vmovupd(tmp3, src1);
- pf_src_t0(2);
- vpermpd(tmp2 | kF0, vidx10, src2);
- vpermpd(tmp3 | kF0, vidx10, src3);
- pf_src_t0(3);
-
- vmovupd(src0, tmp0);
- pf_src_t1(0);
- vmovupd(src1, tmp2);
- pf_src_t1(1);
- vmovupd(src2, tmp1);
- pf_src_t1(2);
- vmovupd(src3, tmp3);
- pf_src_t1(3);
- vpermpd(src0 | kCC, vidx1, tmp1);
- vpermpd(src1 | kCC, vidx1, tmp3);
- pf_tr_src_t0(0);
- vpermpd(src2 | k33, vidx1, tmp0);
- vpermpd(src3 | k33, vidx1, tmp2);
- pf_tr_src_t0(1);
-
- vmovupd(tmp0, src0);
- vmovupd(tmp1, src2);
- pf_tr_src_t0(2);
- vmovupd(tmp2, src1);
- vmovupd(tmp3, src3);
- pf_tr_src_t0(3);
- vpermps(tmp0 | kFFFF, vidxP, src0);
- pf_tr_src_t1(0);
- vpermps(tmp1 | kFFFF, vidxP, src2);
- pf_tr_src_t1(1);
- vpermps(tmp2 | kFFFF, vidxP, src1);
- pf_tr_src_t1(3);
- vpermps(tmp3 | kFFFF, vidxP, src3);
- pf_tr_src_t1(4);
-
- store(tmp0, 0);
- store(tmp1, 1);
- store(tmp2, 2);
- store(tmp3, 3);
-}
-
-alignas(64) static constexpr const int64_t idx01[8]
- = { 0, 0, 0, 0, 0, 1, 2, 3 };
-alignas(64) static constexpr const int64_t idx10[8]
- = { 0, 0, 0, 0, 4, 5, 6, 7 };
-alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 };
-alignas(64) static constexpr const int32_t idxP[16]
- = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
-
-void jit_transpose4x16_src::generate()
-{
- preamble();
-
- const int ic_block = params->ic_block;
- const int is = params->is;
- int tail = is % transpose_size;
-
- src_stride = ic_block * typesize;
- assert(src_stride == 64);
- tr_src_stride = ic_block * typesize;
-
- const int src_step = ic_block * transpose_size * typesize;
- const int tr_src_step = ic_block * transpose_size * typesize;
-
-#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x)
- mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]);
- mov(reg_src, ptr[param1 + GET_TR_OFF(src)]);
- mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]);
- mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]);
- mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]);
-#undef GET_TR_OFF
-
- auto kmovw = [=](Opmask k, unsigned w) {
- mov(regw_tmp, w);
- jit_generator::kmovw(k, regw_tmp);
- };
-
- auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
- mov(imm_addr64, reinterpret_cast<size_t>(addr));
- jit_generator::vmovdqa64(z, ptr[imm_addr64]);
- };
-
- auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
- mov(imm_addr64, reinterpret_cast<size_t>(addr));
- jit_generator::vmovdqa32(z, ptr[imm_addr64]);
- };
-
- kmovw(kF0, 0xf0); // 11110000
- kmovw(kCC, 0xcc); // 11001100
- kmovw(k33, 0x33); // 00110011
- kmovw(kFFFF, 0xffff); // 1111111111111111
-
- vmovdqa64(vidx01, idx01);
- vmovdqa64(vidx10, idx10);
- vmovdqa64(vidx1, idx1);
- vmovdqa32(vidxP, idxP);
-
- Label loop_label;
- Label tail_label;
-
- cmp(reg_loop, transpose_size);
- jl(tail_label, T_NEAR);
-
- L(loop_label);
- {
- transpose(transpose_size);
- add(reg_src, src_step);
- add(reg_tr_src, tr_src_step);
- add(reg_src_prf, src_step);
- add(reg_tr_src_prf, tr_src_step);
- sub(reg_loop, transpose_size);
- cmp(reg_loop, transpose_size);
- jge(loop_label, T_NEAR);
- }
- L(tail_label);
- transpose(tail);
-
- postamble();
-}
-
-jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) {
- if (conf->ver == ver_4fma && !conf->is_1stconv)
- return new jit_trans_iw_ic_t(conf);
- if (conf->ver == ver_4fma && conf->is_1stconv)
- return new jit_trans_iw_x4_4x_t(conf);
- assert(!"unsupported configuration");
- return nullptr;
-}
-
-jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) {
- assert(!"unsupported configuration");
- return nullptr;
-}
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
deleted file mode 100644
index 565e97e4fc..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
+++ /dev/null
@@ -1,145 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_TRANSPOSE_SRC_HPP
-#define CPU_JIT_TRANSPOSE_SRC_HPP
-
-#include "cpu_barrier.hpp"
-#include "jit_primitive_conf.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_trans_src_t {
- struct ctx_t {
- const void *src;
- const void *tr_src;
- const void *src_prf;
- const void *tr_src_prf;
-
- /* 1st conv 4fma: backward by weights */
- int nthr_oc_b; /* number of threads processing a given src image */
- int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
- simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
- };
-
- jit_trans_src_t(const jit_conv_conf_t *conf)
- : conf_(conf), ker_(nullptr) {}
- virtual ~jit_trans_src_t() {}
-
- void operator()(const ctx_t *ctx)
- { assert(ker_); ker_(ctx); }
-
- const jit_conv_conf_t *conf_;
- void (*ker_)(const ctx_t *);
-};
-
-struct jit_src_transpose_s {
- size_t size;
- const void *src;
- const void *tr_src;
- const void *src_prf;
- const void *tr_src_prf;
-};
-
-struct jit_trans_dst_t {
- struct ctx_t {
- const void *src;
- const void *tr_src;
- const void *src_prf;
- const void *tr_src_prf;
-
- /* 1st conv 4fma: backward by weights */
- int nthr_oc_b; /* number of threads processing a given src image */
- int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
- simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
- };
-
- jit_trans_dst_t(const jit_conv_conf_t *conf)
- : conf_(conf), ker_(nullptr) {}
- virtual ~jit_trans_dst_t() {}
-
- void operator()(const ctx_t *ctx)
- { assert(ker_); ker_(ctx); }
-
- const jit_conv_conf_t *conf_;
- void (*ker_)(const ctx_t *);
-};
-
-struct jit_transpose4x16_src_t {
- int src_pf0_distance;
- int tr_src_pf0_distance;
- bool src_pf1;
- bool tr_src_pf1;
-};
-
-struct jit_transpose4x16_src : public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src)
-
- jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams,
- jit_transpose4x16_src_t *tparams_)
- : params(aparams), tparams(tparams_)
- {
- this->generate();
- jit_ker = (decltype(jit_ker))this->getCode();
- }
-
- const jit_1x1_conv_conf_t *params;
- const jit_transpose4x16_src_t *tparams;
- void (*jit_ker)(jit_src_transpose_s *);
-
- void operator()(jit_src_transpose_s *arg) { jit_ker(arg); }
-
- static const int transpose_size = 4;
-private:
- static const int typesize = sizeof(float);
-
- int src_stride, tr_src_stride;
-
- Xbyak::Reg64 imm_addr64 = rbx;
-
- Xbyak::Opmask kF0 = k1;
- Xbyak::Opmask kCC = k2;
- Xbyak::Opmask k33 = k3;
- Xbyak::Opmask kFFFF = k4;
-
- Xbyak::Zmm vidx01 = zmm31;
- Xbyak::Zmm vidx10 = zmm30;
- Xbyak::Zmm vidx1 = zmm29;
- Xbyak::Zmm vidxP = zmm28;
-
- Xbyak::Reg64 reg_src = r8;
- Xbyak::Reg64 reg_tr_src = r9;
- Xbyak::Reg64 reg_src_prf = r10;
- Xbyak::Reg64 reg_tr_src_prf = r11;
- Xbyak::Reg64 reg_loop = r12;
- Xbyak::Reg64 reg_tr_src_tmp = r13;
- Xbyak::Reg32 regw_tmp = r14d;
-
- void transpose_block(int ur, int nrows);
- void transpose(int nrows);
- void generate();
-};
-
-jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
-jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp
deleted file mode 100644
index 53313f9f01..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp
+++ /dev/null
@@ -1,327 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_UNI_1x1_CONV_UTILS_HPP
-#define JIT_UNI_1x1_CONV_UTILS_HPP
-
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-
-struct reduce_to_unit_stride_t {
- convolution_desc_t conv_d_;
- bool reduce_src_;
- size_t space_per_thread_;
-};
-
-/* 1x1-kernel does not support non-unit strides so far, so the idea is:
- * - for fwd or bwd_weights: to copy src to a scratch memory (with strides
- * equal to 1) and then call the kernel
- * - for bwd_data: reduce the problem to the one with unit stride by
- * performing computations in a scratch memory (with strides equal to 1)
- * and then copy the result to diff_src */
-template <typename conv_pd_t>
-inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
- const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
- const bool is_bwd_data = self->desc()->prop_kind
- == prop_kind::backward_data;
-
- const int ndims = src_d->ndims;
- const auto dat_tag = ndims == 3
- ? memory_desc_wrapper(dst_d).matches_one_of_tag(
- format_tag::nCw8c, format_tag::nCw16c)
- : memory_desc_wrapper(dst_d).matches_one_of_tag(
- format_tag::nChw8c, format_tag::nChw16c);
-
- bool rtus_applicable = true
- && utils::pick(ndims - 3,
- (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type,
- data_type::s32)),
- (conv_d->strides[0] != 1 || conv_d->strides[1] != 1))
- && dat_tag != format_tag::undef;
- for (int d = 2; d < ndims; ++d) {
- /* TODO: relax these conditions (by improving reducer) */
- rtus_applicable = rtus_applicable
- && conv_d->padding[0][d - 2] == 0
- && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
- }
-
- if (rtus_applicable) {
- self->rtus_.reduce_src_ = true;
- conv_d = &(self->rtus_.conv_d_ = *conv_d);
- self->rtus_.conv_d_.strides[0] = 1;
- if (ndims == 4)
- self->rtus_.conv_d_.strides[1] = 1;
- utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
- if (ndims == 4)
- utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
- const int ic = src_d->dims[1];
- if (is_bwd_data) {
- src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
- self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
- memory_desc_wrapper::compute_blocking(
- self->rtus_.conv_d_.diff_src_desc, dat_tag);
- } else {
- data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
- src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
- self->rtus_.conv_d_.src_desc.dims[1] = ic;
- self->rtus_.conv_d_.src_desc.data_type = data_type;
- memory_desc_wrapper::compute_blocking(
- self->rtus_.conv_d_.src_desc, dat_tag);
- }
- }
-}
-
-template <typename conv_pd_t>
-inline void rtus_prepare_space_info(conv_pd_t *self,
- memory_tracking::registrar_t &scratchpad) {
- const auto &jcp = self->jcp_;
-
- const int max_threads = mkldnn_get_max_threads();
- const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
- jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
- size_t typesize = types::data_type_size(
- conv_prop_invariant_src_d(self->desc())->data_type);
-
- self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block;
- scratchpad.book(memory_tracking::names::key_conv_rtus_space,
- typesize * max_threads * self->rtus_.space_per_thread_);
-}
-
-template <cpu_isa_t isa>
-struct rtus_driver_t: public jit_generator {
-
- struct call_params_t {
- const void *ws; /* reduced image (w/ strides = 1) */
- const void *src; /* source image (w/ non-unit strides) */
- size_t icb;
- size_t os;
- size_t iw_start;
- };
-
- void (*ker_)(const call_params_t *p);
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)
-
- /* cpu specific part */
- using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
- Xbyak::Zmm>::type;
-
- Xbyak::Reg64 reg_ws = abi_param1;
- Xbyak::Reg64 reg_src = abi_not_param1;
- Xbyak::Reg64 reg_icb = rdx;
- Xbyak::Reg64 reg_os = r11;
- Xbyak::Reg64 reg_iw_start = r8;
-
- Xbyak::Reg64 reg_cur_os = rax;
- Xbyak::Reg64 reg_cur_iw = r9;
- Xbyak::Reg64 reg_cur_src = r10;
-
- int iw_, stride_w_;
- int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
- bool src_to_ws_;
- size_t typesize_;
- Vmm reg_zero;
- Vmm reg_v;
-
- rtus_driver_t(int iw, int stride_w, int src_step_h,
- int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize)
- : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h)
- , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb)
- , src_to_ws_(src_to_ws), typesize_(typesize)
- {
- using namespace Xbyak;
- vlen_ = cpu_isa_traits<isa>::vlen;
- vlen_shift_ = cpu_isa_traits<isa>::vlen_shift;
- if (typesize_ == 2) {
- vlen_ /= 2;
- vlen_shift_--;
- }
-
- reg_zero = Vmm(0);
- reg_v = Vmm(1);
-
- generate();
- }
-
- void loop_is() {
- using namespace Xbyak;
-
- mov(reg_cur_src, reg_src);
- mov(reg_cur_iw, reg_iw_start);
- mov(reg_cur_os, reg_os);
-
- Label is_loop, skip_h_step;
- L(is_loop);
-
- if (src_to_ws_) {
- vmovups(reg_v, ptr[reg_cur_src]);
- vmovups(ptr[reg_ws], reg_v);
- } else {
- vmovups(reg_v, ptr[reg_ws]);
- vmovups(ptr[reg_cur_src], reg_v);
- for (int w = 1; w < stride_w_; ++w)
- vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
- }
-
- add(reg_ws, vlen_);
-
- add(reg_cur_iw, stride_w_);
- add(reg_cur_src, stride_w_ * vlen_);
-
- cmp(reg_cur_iw, iw_);
- jl(skip_h_step);
- /* for 1d convolution the loop over h should be skipped */
- if (src_step_icb_ == iw_) jmp(skip_h_step);
-
- if (src_to_ws_) {
- add(reg_cur_src, (src_step_h_ - iw_) * vlen_);
- } else {
- Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */
- mov(reg_cur_src_fin, reg_cur_src);
- add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_);
- Label ih_loop;
- L(ih_loop);
-
- for (int w = 0; w < stride_w_; ++w)
- vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
-
- add(reg_cur_src, stride_w_ * vlen_);
- cmp(reg_cur_src, reg_cur_src_fin);
- jl(ih_loop);
- }
- xor_(reg_cur_iw, reg_cur_iw);
-
- L(skip_h_step);
-
- sub(reg_cur_os, vlen_);
- jnz(is_loop);
-
- /* restore dst */
- sub(reg_ws, reg_os);
- }
-
- void generate() {
- using namespace Xbyak;
- assert(isa == avx2 || isa == avx512_common
- || isa == avx512_core || isa == avx512_mic);
-
-#if defined(_WIN32)
- assert(reg_src == abi_not_param1 && abi_not_param1 == rdi);
- push(rdi);
-#endif
-
-#define READ_PARAM(what) \
- mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)])
- READ_PARAM(src);
- READ_PARAM(icb);
- READ_PARAM(os);
- READ_PARAM(iw_start);
-
- assert(reg_ws == abi_param1);
- READ_PARAM(ws); /* reg_ws should always be read the last */
-#undef READ_PARAM
-
- shl(reg_os, vlen_shift_);
-
- if (!src_to_ws_)
- uni_vpxor(reg_zero, reg_zero, reg_zero);
-
- Label icb_loop;
- L(icb_loop);
-
- loop_is();
-
- add(reg_ws, ws_step_icb_ * vlen_);
- add(reg_src, src_step_icb_ * vlen_);
-
- dec(reg_icb);
- jnz(icb_loop, T_NEAR);
-
-#if defined(_WIN32)
- pop(rdi);
-#endif
-
- uni_vzeroupper();
- ret();
- this->ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
- this->getCode()));
- }
-};
-
-template <cpu_isa_t isa, typename conv_t>
-inline void init_rtus_driver(conv_t *self) {
- const auto &conf = *self->pd();
- if (!conf.rtus_.reduce_src_) return;
-
- const auto &cd = *conf.desc();
- const int ndims = conf.ndims();
- const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
- const int stride_w = cd.strides[ndims - 3];
-
- const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
- const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md();
-
- const int ih = ndims == 3 ? 1 : src_d.dims[2];
- const int iw = src_d.dims[ndims - 1];
-
- const int src_step_h = stride_h * iw;
- const int src_step_icb = ih * iw;
- const int ws_step_icb = conf.jcp_.is;
- const bool src_to_ws = !is_bwd_data;
- const size_t typesize = types::data_type_size(
- conv_prop_invariant_src_d(self->pd()->desc())->data_type);
-
- self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h,
- src_step_icb, ws_step_icb, src_to_ws, typesize);
-}
-
-inline int best_divider(int value, int min_divider, int max_divider,
- bool find_max, int step = 1)
-{
- max_divider = nstl::max(1, nstl::min(max_divider, value));
- min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
-
- auto loss_ratio = [](int total, int chunk)
- { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); };
-
- float min_loss = FLT_MAX;
- int x_divider = max_divider;
- for (int divider = max_divider; divider >= min_divider; divider -= step) {
- const float loss = loss_ratio(value, divider);
- if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
- min_loss = loss;
- x_divider = divider;
- }
- }
- return x_divider;
-}
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
deleted file mode 100644
index 72fe3a8109..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
+++ /dev/null
@@ -1,1407 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "math_utils.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_barrier.hpp"
-#include "cpu_batch_normalization_utils.hpp"
-#include "jit_generator.hpp"
-
-#include "jit_uni_batch_normalization.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace {
-
-using namespace memory_tracking::names;
-
-using namespace Xbyak;
-namespace barrier = simple_barrier;
-
-typedef float data_t;
-
-template <cpu_isa_t isa>
-struct jit_bnorm_t: public jit_generator {
- struct call_params_t {
- // keep all sizes at 8 bytes -- jit code expects this
- size_t N_ithr, N_nthr;
- size_t coff_max, soff_max;
- size_t mb_stride_Bc, spat_size, spat_size_loc;
- size_t S_s, S_tail;
- size_t is_cblk_tail;
- data_t chan_size, eps, one;
- const data_t *scale_shift;
- const data_t *mean, *var;
- const data_t *diff_scale_shift;
- const data_t *src, *dst;
- const data_t *diff_src, *diff_dst;
- const data_t *rbuf1, *rbuf2;
- const uint8_t *ws;
- barrier::ctx_t *barrier;
- };
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)
-
- /* cpu specific part */
- using Vmm = typename utils::conditional3<isa == sse42, Xmm,
- isa == avx2, Ymm, Zmm>::type;
- const AddressFrame &vmmword = (isa == sse42) ? xword :
- (isa == avx2) ? yword : zword;
-
- const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
-
- const batch_normalization_pd_t *bdesc_;
- bool is_spatial_thr_;
-
- void (*ker)(const call_params_t *);
- void operator()(const call_params_t *p) { (*ker)(p); }
-
- Reg64 reg_param = abi_param1;
-
- Reg64 reg_scale_shift = rbx;
- Reg64 reg_rbuf1 = abi_not_param1;
- Reg64 reg_rbuf2 = rdx;
-
- Reg64 reg_mean = rbp;
- Reg64 reg_var = reg_param;
- Reg64 reg_diff_scale_shift = rax;
-
- Reg64 reg_coff = r8;
- Reg64 reg_coff_max = r9;
- Reg64 reg_soff = r10;
- Reg64 reg_soff_max = r11;
- Reg64 reg_ctr = r12;
- Reg64 reg_roff = r13;
-
- Reg64 reg_mb_stride_Bc = r14;
-
- Reg64 reg_src = r15;
- Reg64 reg_diff_src = reg_rbuf1;
- Reg64 reg_dst = rsi;
- Reg64 reg_diff_dst = reg_dst;
-
- Reg64 reg_tmp_off = reg_roff;
-
- // Reuse loop counters
- Reg64 reg_bar = reg_coff;
- Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
- Reg64 reg_tmp = reg_ctr;
-
- // Relu section
- bool with_relu, with_relu_inf_only;
- Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
- Reg64 reg_ws = reg_roff;
- Label l_relu_mask_avx2;
- Opmask kstore_mask = Opmask(1);
-
- // channel tail processing
- Opmask ktail_mask = Opmask(2);
-
- size_t unroll_blocks;
- size_t unroll_regs;
- Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
- Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
- Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
- Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
- Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
- Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
- Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
- Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
- Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
- Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
- Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);
-
- size_t t0_pf_offt;
- size_t t1_pf_offt;
- size_t spat_size;
- size_t chan_data_offt;
-
- enum {
- stack_off_N_nthr = 0,
- stack_off_N_ithr = 8,
- stack_off_src = 16,
- stack_off_dst = 24,
- stack_off_diff_src = 32,
- stack_off_diff_dst = 40,
- stack_off_diff_scale_shift = 48,
- stack_off_ws = 56,
- stack_off_barrier = 64,
- stack_off_spat_size_loc = 72,
- stack_off_s_s = 80,
- stack_off_s_tail = 88,
- stack_off_is_cblk_tail = 96,
- stack_size_required = 104,
- };
-
- bool is_c_padded() const {
- const memory_desc_wrapper data_d(bdesc_->src_md());
- return bdesc_->C() != data_d.padded_dims()[1];
- }
-
- void compute_static_strides() {
- spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
- chan_data_offt = bdesc_->C() * sizeof(data_t);
-
- if (isa == avx512_mic) {
- t0_pf_offt = 4096;
- t1_pf_offt = 0;
- } else {
- t0_pf_offt = 0;
- t1_pf_offt = 0;
- }
- }
-
- void load_common_params() {
-# define PARAM_OFF(x) offsetof(call_params_t, x)
- mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
- if (bdesc_->is_bwd())
- mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
- mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
- mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
- mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
- shl(reg_coff_max, 2);
- shl(reg_soff_max, 2);
- shl(reg_mb_stride_Bc, 2);
-
- mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
- mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);
-
- uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
- uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
- uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
-
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
- mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
- mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
- mov(ptr[rsp + stack_off_src], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
- mov(ptr[rsp + stack_off_dst], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
- mov(ptr[rsp + stack_off_diff_src], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
- mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
- mov(ptr[rsp + stack_off_ws], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
- mov(ptr[rsp + stack_off_barrier], reg_tmp);
- if (is_spatial_thr_) {
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
- mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
- mov(ptr[rsp + stack_off_s_s], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
- mov(ptr[rsp + stack_off_s_tail], reg_tmp);
- }
- if (is_c_padded()) {
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
- mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
- }
-
- if (bdesc_->is_fwd()) {
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
- mov(reg_var, reg_tmp);
- } else {
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
- mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
- mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
- mov(reg_var, reg_tmp);
- }
-# undef PARAM_OFF
- }
-
- void prepare_tail_mask_avx512_common() {
- if (!is_c_padded()) return;
-
- const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
- const int mask = (1 << tail) - 1;
-
- Reg32 regw_tmp = reg_tmp.cvt32();
- mov(regw_tmp, mask);
- kmovw(ktail_mask, regw_tmp);
- }
-
- void prepare_tail_mask_avx2_common() {
- if (!is_c_padded()) return;
-
- const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
- static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
- 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
- 0, 0, 0, 0, 0, 0, 0, 0};
-
- mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
- vmovups(vtail_mask, ptr[reg_tmp]);
- }
-
- void prepare_relu() {
- with_relu = bdesc_->is_fwd()
- ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
- : bdesc_->fuse_bn_relu();
- with_relu_inf_only = with_relu && bdesc_->is_fwd()
- && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());
-
- vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
- if (with_relu) {
- uni_vpxor(vzero, vzero, vzero);
- if (!bdesc_->is_fwd() && isa == avx2)
- prepare_l_relu_mask_avx2();
- }
- }
-
- void prepare_l_relu_mask_avx2() {
- Label l_mask_after;
- jmp(l_mask_after);
- align(32);
- L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
- for (int i = 0; i < 8; ++i) dd(1<<i);
- L(l_mask_after);
- }
-
- void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
- Reg64 reg_store_mask = reg_diff_scale_shift;
- shr(reg_soff, 5);
- vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
- vmovmskps(reg_store_mask, vstore_mask);
- mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
- vblendvps(vdst, vzero, vdst, vstore_mask);
- shl(reg_soff, 5);
- }
-
- void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
- shr(reg_soff, 5);
- vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
- kmovw(ptr[reg_ws + reg_soff + offt / (1 << 5)], kstore_mask);
- vblendmps(vdst | kstore_mask, vzero, vdst);
- shl(reg_soff, 5);
- }
-
- void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
- shr(reg_soff, 5);
- vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
- vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
- vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
- vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
- shl(reg_soff, 5);
- }
-
- void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
- shr(reg_soff, 5);
- kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
- vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
- shl(reg_soff, 5);
- }
-
- void uni_vmovups_tail_avx2_common(const Operand &dst,
- const Operand &src, Label &l_ret) {
- if (dst.isMEM()) {
- vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
- } else {
- vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
- }
- jmp(l_ret);
- }
-
- void uni_vmovups_tail_avx512_common(const Operand &dst,
- const Operand &src, Label &l_ret) {
- if (dst.isMEM())
- uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
- else
- uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
-
- jmp(l_ret);
- }
-
- void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
- Label l_no_mask, l_ret;
-
- if (is_c_padded()) {
- mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
- cmp(reg_tmp, 0);
- jz(l_no_mask);
-
- lea(reg_tmp, ptr[reg_coff + vlen]);
- cmp(reg_tmp, reg_coff_max);
- jl(l_no_mask);
- assert(isa == avx512_common || isa == avx2);
- if (isa == avx512_common)
- uni_vmovups_tail_avx512_common(dst, src, l_ret);
- else if (isa == avx2)
- uni_vmovups_tail_avx2_common(dst, src, l_ret);
- }
- L(l_no_mask);
- if (dst.isMEM())
- uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
- else
- uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
-
- L(l_ret);
- }
-
- void barrier() {
- mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
- mov(reg_bar, ptr[rsp + stack_off_barrier]);
- simple_barrier::generate(*this, reg_bar, reg_nnthr);
- }
-
- Address mean_ptr(size_t offt = 0) {
- return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
- }
-
- Address var_ptr(size_t offt = 0) {
- return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
- }
-
- Address diff_gamma_ptr(size_t offt = 0) {
- return vmmword[reg_diff_scale_shift + reg_coff + offt
- + 0 * chan_data_offt];
- }
-
- Address diff_beta_ptr(size_t offt = 0) {
- return vmmword[reg_diff_scale_shift + reg_coff + offt
- + 1 * chan_data_offt];
- }
-
- Address gamma_ptr(size_t offt = 0) {
- return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
- }
-
- Address beta_ptr(size_t offt = 0) {
- return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
- }
-
- template <typename init_t, typename body_t, typename fini_t>
- void spat_loop(size_t len, size_t blocks, size_t regs,
- init_t init, body_t body, fini_t fini) {
- size_t factor = regs * blocks;
- size_t loop_unroll = len / factor * factor;
- size_t loop_tail = len - loop_unroll;
- size_t num_active_regs = (len < regs) ? len : regs;
- for (size_t i = 0; i < num_active_regs; i++)
- init(i);
- if (loop_unroll) {
- if (is_spatial_thr_) {
- mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
- add(reg_soff, ptr[rsp + stack_off_s_s]);
- } else {
- mov(reg_ctr, loop_unroll);
- }
- Label label;
- L(label); {
- for (size_t i = 0; i < factor; i++) {
- size_t base_reg = i % regs;
- body(base_reg, i);
- }
- add(reg_soff, factor * vlen);
- sub(reg_ctr, factor);
- jnz(label);
- }
- if (is_spatial_thr_) {
- add(reg_soff, ptr[rsp + stack_off_s_tail]);
- }
- }
-
- for (size_t i = 0; i < loop_tail; i++) {
- size_t base_reg = i % regs;
- body(base_reg, i);
- }
- if (loop_tail)
- add(reg_soff, loop_tail * vlen);
-
- for (size_t i = 0; i < num_active_regs; i++)
- fini(i);
- }
-
- void mean_channels() {
- Label ch_label;
- L(ch_label); {
- uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
- spat_loop(spat_size, unroll_blocks,
- unroll_regs,
- [=](size_t base_reg) {
- Vmm v = Vmm(base_reg * 2);
- if (base_reg)
- uni_vpxor(v, v, v);
- },
- [=](size_t base_reg, size_t i) {
- Vmm v0 = Vmm(base_reg * 2 + 0);
- Vmm v1 = Vmm(base_reg * 2 + 1);
- size_t offt = i * vlen;
- uni_vmovups(v1,
- vmmword[reg_src + reg_soff + offt]);
- uni_vaddps(v0, v0, v1);
- mic_prefetcht0(ptr[reg_src + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht1(ptr[reg_src + reg_soff + offt
- + t1_pf_offt]);
- },
- [=](size_t base_reg) {
- Vmm b = Vmm(0);
- Vmm v = Vmm(base_reg * 2);
- if (base_reg)
- uni_vaddps(b, b, v);
- });
- uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
-
- add(reg_coff, vlen);
- cmp(reg_coff, reg_coff_max);
- jl(ch_label);
- }
- }
-
- void var_channels() {
- Label ch_label;
- L(ch_label); {
- uni_vmovups_maybe_tail(vmean, mean_ptr());
- uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
- spat_loop(spat_size, unroll_blocks, unroll_regs,
- [=](size_t base_reg) {
- Vmm v = Vmm(base_reg * 3);
- if (base_reg > 0)
- uni_vpxor(v, v, v);
- },
- [=](size_t base_reg, size_t i) {
- Vmm v = Vmm(3 * base_reg);
- Vmm vtmp0 = Vmm(3 * base_reg + 1);
- Vmm vtmp1 = Vmm(3 * base_reg + 2);
- size_t offt = i * vlen;
- uni_vmovups(vtmp0,
- vmmword[reg_src + reg_soff + offt]);
- if (isa == sse42) {
- movups(vtmp1, vmean);
- subps(vtmp1, vtmp0);
- } else {
- vsubps(vtmp1, vmean, vtmp0);
- }
- uni_vfmadd231ps(v, vtmp1, vtmp1);
-
- mic_prefetcht0(ptr[reg_src + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht1(ptr[reg_src + reg_soff + offt
- + t1_pf_offt]);
- },
- [=](size_t base_reg) {
- Vmm b = Vmm(0);
- Vmm v = Vmm(base_reg * 3);
- if (base_reg)
- uni_vaddps(b, b, v);
- });
- uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
- add(reg_coff, vlen);
- cmp(reg_coff, reg_coff_max);
- jl(ch_label);
- }
- }
-
- void compute_mean_variance() {
- uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
- xor_(reg_coff, reg_coff);
- Label zero_rbuf;
- L(zero_rbuf); {
- uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
- add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
- cmp(reg_coff, reg_coff_max);
- jne(zero_rbuf);
- }
-
- mov(reg_src, ptr[rsp + stack_off_src]);
-
- xor_(reg_soff, reg_soff);
- Label mean_spatial;
- L(mean_spatial); {
- xor_(reg_coff, reg_coff);
-
- if (isa == sse42)
- mov(reg_tmp_off, reg_soff);
-
- mean_channels();
-
- if (isa == sse42) {
- mov(reg_soff, reg_tmp_off);
- add(reg_src, vlen / 2);
- mov(reg_coff, vlen / 2);
-
- mean_channels();
-
- sub(reg_src, vlen / 2);
- }
-
- add(reg_soff, reg_mb_stride_Bc);
- cmp(reg_soff, reg_soff_max);
- jne(mean_spatial);
- }
-
- Label no_mean_reduction;
- barrier(); {
- mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
- cmp(reg_tmp, 0);
- jne(no_mean_reduction);
- mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
- xor_(reg_coff, reg_coff);
- Label mean_reduction_channels;
- L(mean_reduction_channels); {
- mov(reg_roff, reg_coff);
- uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
- uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
- mov(reg_ctr, reg_nnthr);
- Label mean_reduction_thrs;
- L(mean_reduction_thrs); {
- uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
- uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
- add(reg_roff, reg_coff_max);
- sub(reg_ctr, 1);
- jnz(mean_reduction_thrs);
- }
- uni_vdivps(Vmm(1), Vmm(1), vchan_size);
- uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));
-
- add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
-
- cmp(reg_coff, reg_coff_max);
- jne(mean_reduction_channels);
- }
- }
- L(no_mean_reduction);
- barrier();
-
- xor_(reg_soff, reg_soff);
- Label var_spatial;
- L(var_spatial); {
- xor_(reg_coff, reg_coff);
-
- if (isa == sse42)
- mov(reg_tmp_off, reg_soff);
-
- var_channels();
-
- if (isa == sse42) {
- mov(reg_soff, reg_tmp_off);
- add(reg_src, vlen / 2);
- mov(reg_coff, vlen / 2);
-
- var_channels();
-
- sub(reg_src, vlen / 2);
- }
-
- add(reg_soff, reg_mb_stride_Bc);
- cmp(reg_soff, reg_soff_max);
- jne(var_spatial);
- }
-
- Label no_var_reduction;
- barrier(); {
- mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
- cmp(reg_tmp, 0);
- jne(no_var_reduction);
-
- mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
- xor_(reg_coff, reg_coff);
- Label var_reduction_channels;
- L(var_reduction_channels); {
- mov(reg_roff, reg_coff);
- uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
- mov(reg_ctr, reg_nnthr);
- Label var_reduction_thrs;
- L(var_reduction_thrs); { // TODO: unroll (?)
- uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
- add(reg_roff, reg_coff_max);
- sub(reg_ctr, 1);
- jnz(var_reduction_thrs);
- }
- uni_vdivps(Vmm(1), Vmm(1), vchan_size);
- uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
- add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
-
- cmp(reg_coff, reg_coff_max);
- jne(var_reduction_channels);
- }
- }
- L(no_var_reduction);
- barrier();
- }
-
- void forward_channels() {
- Label ch_label;
- L(ch_label); {
- uni_vmovups_maybe_tail(vmean, mean_ptr());
- uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
- uni_vaddps(vsqrtvar, vsqrtvar, veps);
- uni_vsqrtps(vsqrtvar, vsqrtvar);
-
- if (bdesc_->use_scaleshift()) {
- uni_vmovups_maybe_tail(vgamma, gamma_ptr());
- uni_vmovups_maybe_tail(vbeta, beta_ptr());
- }
-
- Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone;
- Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar;
-
- if (isa == sse42) {
- movups(vbuf, vscale);
- divps(vbuf, vsqrtvar);
- movups(vdiv, vbuf);
- } else {
- vdivps(vdiv, vscale, vsqrtvar);
- }
-
- auto compute = [=](bool output_is_aligned) {
- spat_loop(spat_size, unroll_blocks, unroll_regs,
- [](size_t base_reg) {UNUSED(base_reg);},
- [=](size_t base_reg, size_t i) {
- Vmm v = Vmm(base_reg);
- size_t offt = i * vlen;
- uni_vmovups(v,
- vmmword[reg_src + reg_soff + offt]);
- mic_prefetcht0(ptr[reg_src + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht1(ptr[reg_src + reg_soff + offt
- + t1_pf_offt]);
- uni_vsubps(v, v, vmean);
- if (bdesc_->use_scaleshift()) {
- uni_vfmadd213ps(v, vgamma, vbeta);
- } else {
- uni_vmulps(v, v, vsqrtvar);
- }
- if (with_relu_inf_only) {
- uni_vmaxps(v, v, vzero);
- } else if (with_relu) {
- if (isa == avx512_common)
- fwd_process_relu_avx512_common(v, offt);
- else
- fwd_process_relu_avx2(v, offt, Vmm(3));
- }
- if (output_is_aligned) {
- uni_vmovntps(
- vmmword[reg_dst + reg_soff + offt], v);
- } else {
- uni_vmovups(
- vmmword[reg_dst + reg_soff + offt], v);
- }
- },
- [](size_t base_reg) {UNUSED(base_reg);});
- };
-
- Label unaligned_store, end_store;
- test(reg_dst, vlen - 1);
- jnz(unaligned_store, T_NEAR);
- compute(true);
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- compute(false);
- }
- L(end_store);
-
- add(reg_coff, vlen);
- cmp(reg_coff, reg_coff_max);
- jl(ch_label);
- }
- }
-
- void forward() {
- mov(reg_src, ptr[rsp + stack_off_src]);
- mov(reg_dst, ptr[rsp + stack_off_dst]);
- mov(reg_ws, ptr[rsp + stack_off_ws]);
-
- xor_(reg_soff, reg_soff);
- Label dst_spatial;
- L(dst_spatial); {
- xor_(reg_coff, reg_coff);
- if (isa == sse42)
- mov(reg_tmp_off, reg_soff);
-
- forward_channels();
-
- if (isa == sse42) {
- mov(reg_soff, reg_tmp_off);
- add(reg_src, vlen / 2);
- add(reg_dst, vlen / 2);
- mov(reg_coff, vlen / 2);
-
- forward_channels();
-
- sub(reg_src, vlen / 2);
- sub(reg_dst, vlen / 2);
- }
-
- add(reg_soff, reg_mb_stride_Bc);
- cmp(reg_soff, reg_soff_max);
- jnz(dst_spatial);
- }
- }
-
- void backward_sh_channels() {
- Label sh_channels;
- L(sh_channels); {
- uni_vmovups_maybe_tail(vmean, mean_ptr());
- uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
- uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
- spat_loop(spat_size, 1, 1,
- [=](size_t base_reg) {
- if (base_reg > 0) {
- for (int i = 0; i < 2; i++) {
- Vmm v(base_reg * 5 + i);
- uni_vpxor(v, v, v);
- }
- }
- },
- [=](size_t base_reg, size_t i) {
- Vmm o0 = Vmm(base_reg * 5 + 0);
- Vmm o1 = Vmm(base_reg * 5 + 1);
- Vmm t1 = Vmm(base_reg * 5 + 2);
- Vmm t2 = Vmm(base_reg * 5 + 3);
- Vmm t3 = Vmm(base_reg * 5 + 4);
- size_t offt = i * vlen;
- uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]);
- uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff
- + offt]);
- if (with_relu) {
- if (isa == avx512_common)
- bwd_process_relu_avx512_common(t2, offt);
- else if (isa == avx2)
- bwd_process_relu_avx2(t2, offt, t3);
- else
- assert(false);
- }
- uni_vsubps(t3, vmean, t1, t3);
- if (isa == sse42) {
- mulps(t3, t2);
- subps(o0, t3);
- } else {
- vfnmadd231ps(o0, t3, t2);
- }
- uni_vaddps(o1, o1, t2);
- mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht0(ptr[reg_src + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
- + t1_pf_offt]);
- mic_prefetcht1(ptr[reg_src + reg_soff + offt
- + t1_pf_offt]);
- },
- [=](size_t base_reg) {
- Vmm b0 = Vmm(0);
- Vmm b1 = Vmm(1);
- if (base_reg) {
- uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
- uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
- }
- });
- uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
- uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
- add(reg_coff, vlen);
- cmp(reg_coff, reg_coff_max);
- jl(sh_channels);
- }
- }
-
- void backward_diff_channels() {
- Label diff_channels;
- L(diff_channels); {
- uni_vmovups_maybe_tail(vmean, mean_ptr());
- uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
- uni_vaddps(vsqrtvar, vsqrtvar, veps);
- uni_vsqrtps(vsqrtvar, vsqrtvar);
- uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
- if (bdesc_->use_scaleshift())
- uni_vmovups_maybe_tail(vgamma, gamma_ptr());
- uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
- uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
- uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
- uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
- uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);
-
- auto compute = [=](bool output_is_aligned) {
- spat_loop(spat_size, unroll_blocks, unroll_regs,
- [=](size_t base_reg) {UNUSED(base_reg);},
- [=](size_t base_reg, size_t i) {
- Vmm v(base_reg * 2 + 0);
- Vmm t(base_reg * 2 + 1);
- Vmm t1(base_reg * 2 + 2);
- size_t offt = i * vlen;
- uni_vmovups(v, vmmword[reg_diff_dst + reg_soff
- + offt]);
- if (with_relu) {
- if (isa == avx512_common)
- bwd_process_relu_avx512_common(v, offt);
- else if (isa == avx2)
- bwd_process_relu_avx2(v, offt, t);
- else
- assert(false);
- }
- if (!bdesc_->use_global_stats()) {
- uni_vsubps(v, v, vdiff_beta);
- uni_vmovups(t, vmmword[reg_src + reg_soff
- + offt]);
- uni_vsubps(t, vmean, t, t1);
- uni_vmulps(t, t, vdiff_gamma);
- uni_vaddps(v, v, t);
- }
- uni_vmulps(v, v, vsqrtvar);
- if (bdesc_->use_scaleshift()) {
- uni_vmulps(v, v, vgamma);
- }
- if (output_is_aligned) {
- uni_vmovntps(
- vmmword[reg_diff_src + reg_soff + offt],
- v);
- } else {
- uni_vmovups(
- vmmword[reg_diff_src + reg_soff + offt],
- v);
- }
- mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht0(ptr[reg_src + reg_soff + offt
- + t0_pf_offt]);
- mic_prefetcht1(ptr[reg_diff_dst + reg_soff
- + offt + t1_pf_offt]);
- mic_prefetcht1(ptr[reg_src + reg_soff + offt
- + t1_pf_offt]);
- },
- [=](size_t base_reg) {UNUSED(base_reg);});
- };
-
- Label unaligned_store, end_store;
- test(reg_diff_src, vlen - 1);
- jnz(unaligned_store, T_NEAR);
- compute(true);
- jmp(end_store, T_NEAR);
- L(unaligned_store); {
- compute(false);
- }
- L(end_store);
-
- add(reg_coff, vlen);
- cmp(reg_coff, reg_coff_max);
- jl(diff_channels);
- }
- }
-
- void backward() {
- uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
- xor_(reg_coff, reg_coff);
- Label zero_rbuf, sh_spatial;
-
- L(zero_rbuf); {
- uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
- uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
- add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
- cmp(reg_coff, reg_coff_max);
- jne(zero_rbuf);
- }
-
- mov(reg_src, ptr[rsp + stack_off_src]);
- mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
- if (with_relu) {
- assert(isa == avx2 || isa == avx512_common);
- mov(reg_ws, ptr[rsp + stack_off_ws]);
- }
-
- xor_(reg_soff, reg_soff);
- L(sh_spatial); {
- xor_(reg_coff, reg_coff);
- if (isa == sse42) {
- mov(reg_tmp_off, reg_soff);
- }
- backward_sh_channels();
- if (isa == sse42) {
- mov(reg_soff, reg_tmp_off);
- add(reg_diff_dst, vlen / 2);
- add(reg_src, vlen / 2);
- mov(reg_coff, vlen / 2);
- backward_sh_channels();
- sub(reg_diff_dst, vlen / 2);
- sub(reg_src, vlen / 2);
- }
- add(reg_soff, reg_mb_stride_Bc);
- cmp(reg_soff, reg_soff_max);
- jne(sh_spatial);
- }
-
- mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);
-
- Label no_sh_reduction;
- barrier(); {
- mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
- cmp(reg_tmp, 0);
- Label sh_reduction_channels;
- jne(no_sh_reduction, T_NEAR);
-
- mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
- xor_(reg_coff, reg_coff);
- L(sh_reduction_channels); {
- mov(reg_roff, reg_coff);
- uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
- uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
- uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
- uni_vaddps(vsqrtvar, vsqrtvar, veps);
- uni_vsqrtps(vsqrtvar, vsqrtvar);
- uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
- mov(reg_ctr, reg_nnthr);
- Label sh_reduction_thrs;
- L(sh_reduction_thrs); { // TODO: unroll (?)
- uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
- uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
- add(reg_roff, reg_coff_max);
- sub(reg_ctr, 1);
- jnz(sh_reduction_thrs);
- }
- uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
- uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
- uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
- add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
- cmp(reg_coff, reg_coff_max);
- jne(sh_reduction_channels);
- }
- }
- L(no_sh_reduction);
- barrier();
-
- mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
- if (with_relu) {
- assert(isa == avx2 || isa == avx512_common);
- mov(reg_ws, ptr[rsp + stack_off_ws]);
- }
-
- xor_(reg_soff, reg_soff);
- Label diff_spatial;
- L(diff_spatial); {
- xor_(reg_coff, reg_coff);
- if (isa == sse42) {
- mov(reg_tmp_off, reg_soff);
- }
- backward_diff_channels();
- if (isa == sse42) {
- mov(reg_soff, reg_tmp_off);
- add(reg_diff_dst, vlen / 2);
- add(reg_diff_src, vlen / 2);
- add(reg_src, vlen / 2);
- mov(reg_coff, vlen / 2);
- backward_diff_channels();
- sub(reg_diff_dst, vlen / 2);
- sub(reg_diff_src, vlen / 2);
- sub(reg_src, vlen / 2);
- }
- add(reg_soff, reg_mb_stride_Bc);
- cmp(reg_soff, reg_soff_max);
- jne(diff_spatial);
- }
- }
-
- jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
- static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
- || isa == avx512_mic, "unsupported isa");
-
- const int simd_w = isa == sse42 ? 8 :
- cpu_isa_traits<isa>::vlen / sizeof(data_t);
- is_spatial_thr_ =
- bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));
-
- unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
- unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
-
- preamble();
-
- if (isa == avx512_common)
- prepare_tail_mask_avx512_common();
- else if (isa == avx2)
- prepare_tail_mask_avx2_common();
-
- compute_static_strides();
- sub(rsp, stack_size_required);
- load_common_params();
- prepare_relu();
-
- if (bdesc_->is_fwd()) {
- if (!bdesc_->stats_is_src()) {
- compute_mean_variance();
- }
- forward();
- } else {
- backward();
- }
- add(rsp, stack_size_required);
- postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
- }
-};
-
-template <cpu_isa_t isa>
-struct uni_bnorm_driver_t: public c_compatible {
- uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
- : bdesc_(bdesc), ker_(bdesc_)
- {
- const int nthrs = mkldnn_get_max_threads();
- const dim_t C_PADDED = get_c_padded(bdesc_);
-
- size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
- * bdesc_->D() * bdesc_->H() * bdesc_->W();
- l3_size_ = get_cache_size(3, true) * nthrs / 2;
- do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
- }
-
- ~uni_bnorm_driver_t() {}
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const batch_normalization_pd_t *bdesc) {
- int nthrs = mkldnn_get_max_threads();
- dim_t C_PADDED = get_c_padded(bdesc);
-
- int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
- int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
- int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
-
- scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
- scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
- scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);
-
- if (mkldnn_thr_syncable()) {
- int n_barriers = C_PADDED / simd_w;
- scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
- }
- }
-
- void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
- data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
- data_t *diff_scale_shift, const data_t *mean, const data_t *var,
- const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
- auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
- auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
- auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
- auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
-
- dim_t N = bdesc_->MB();
- dim_t C = bdesc_->C();
- dim_t C_PADDED = get_c_padded(bdesc_);
- dim_t D = bdesc_->D();
- dim_t H = bdesc_->H();
- dim_t W = bdesc_->W();
- dim_t SP = D * H * W;
- dim_t img_size = C_PADDED * D * H * W;
- const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
-
- typename jit_bnorm_t<isa>::call_params_t p;
-
- p.eps = bdesc_->desc()->batch_norm_epsilon;
- p.one = 1.0f;
- p.spat_size = D * H * W;
- p.chan_size = 1.0f * N * p.spat_size;
-
- dim_t C_blks = C_PADDED / simd_w;
-
- int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
- dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};
-
- dim_t C_blks_per_iter{ 1 };
- int64_t iters{ 1 };
- if (do_blocking_) {
- int num_tensors = bdesc_->is_fwd() ? 1 : 2;
- size_t working_set_size
- = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors;
- bnorm_utils::cache_balance(working_set_size, C_blks,
- C_blks_per_iter, iters);
- }
-
- bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
- true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
- SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
- S_ithr, S_nthr, S_s, S_e);
-
- int SP_N_ithr = N_ithr * S_nthr + S_ithr;
- int SP_N_nthr = N_nthr * S_nthr;
- assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));
-
- p.N_ithr = SP_N_ithr;
- p.N_nthr = SP_N_nthr;
-
- int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;
- int global_C_blk_s;
- int global_barriers_per_iter = C_nthr;
-
- for (int64_t it = 0; it < iters; it++) {
- if (it == iters - 1 && iters > 1) {
- C_blk_s = C_blk_e = N_s = N_e = 0;
- spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
- spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
- C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
- N_e, S_ithr, S_nthr, S_s, S_e);
-
- // Update call parameters for JIT, last iteration
- p.N_ithr = N_ithr * S_nthr + S_ithr;
- p.N_nthr = N_nthr * S_nthr;
- }
-
- global_C_blk_s = do_blocking_ ?
- (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
- C_blk_s;
-
- int C_blks_thr = C_blk_e - C_blk_s;
- int N_thr = N_e - N_s;
-
- size_t coff_base = global_C_blk_s * simd_w;
- size_t soff_base
- = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;
-
- p.spat_size_loc = S_e - S_s;
- p.S_s = S_s * vlen;
- p.S_tail = (p.spat_size - S_e) * vlen;
- p.coff_max = C_blks_thr * simd_w;
- p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
- p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
- p.scale_shift = scale_shift + coff_base;
- p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
- ? pbuf : diff_scale_shift) + coff_base;
-
- p.soff_max = N_thr * img_size;
- p.src = src + soff_base;
- p.dst = dst + soff_base;
- p.diff_src = diff_src + soff_base;
- p.diff_dst = diff_dst + soff_base;
- p.ws = ws + soff_base / 8;
-
- p.mb_stride_Bc = img_size - p.coff_max * p.spat_size;
-
- // use SP_N_nthr which is the same as p.N_nthr except maybe for
- // the last iteration.
- p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
- + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
- // rbuf1 and rbuf2 have to be disjoint
- p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
- p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C;
-
- size_t iter_barriers
- = do_blocking_ ? it * global_barriers_per_iter : 0;
- p.barrier = barriers + C_ithr + iter_barriers;
- if (p.soff_max != 0 && p.coff_max != 0)
- ker_(&p);
- }
- }
-
- void init_barriers(const memory_tracking::grantor_t &scratchpad) {
- auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
- if (barriers) {
- const int n_barriers = get_c_padded(bdesc_) / simd_w;
- for (int i = 0; i < n_barriers; ++i)
- barrier::ctx_init(&barriers[i]);
- }
- }
-
-private:
- enum {
- simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
- };
-
- static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
- return true
- && !bdesc->stats_is_src()
- && bdesc->desc()->prop_kind == prop_kind::forward_inference;
- }
-
- static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
- {
- return false
- || (bdesc->is_bwd() && !bdesc->use_scaleshift())
- || bdesc->desc()->prop_kind == prop_kind::backward_data;
- }
-
- static dim_t get_c_padded(const batch_normalization_pd_t *bdesc)
- { return bdesc->src_md()->padded_dims[1]; }
-
- const batch_normalization_pd_t *bdesc_;
- bool do_blocking_;
- size_t l3_size_;
-
- jit_bnorm_t<isa> ker_;
-};
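The constructor above decides whether to block over channel groups by comparing one tensor's working set against roughly half of the thread team's combined L3 share. A minimal sketch of that decision follows; the parameter names are illustrative stand-ins for the bdesc_ queries and get_cache_size(3, true), not the library API.

    #include <cstddef>

    // Sketch only: mirrors the do_blocking_ decision in uni_bnorm_driver_t's
    // constructor. l3_per_core stands in for get_cache_size(3, true).
    static bool should_block_over_channels(std::size_t elem_size, std::size_t mb,
                                           std::size_t c_padded, std::size_t spatial,
                                           std::size_t l3_per_core, int nthreads) {
        const std::size_t data_size = elem_size * mb * c_padded * spatial;
        const std::size_t l3_size = l3_per_core * nthreads / 2;
        return l3_size > 0 && data_size >= l3_size / 2;
    }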
-
-}
-
-using namespace data_type;
-using namespace format_tag;
-using namespace utils;
-
-/* fwd */
-
-template <cpu_isa_t isa>
-status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
- auto desired_fmt_tag = (ndims() == 4)
- ? isa == avx512_common ? nChw16c : nChw8c
- : isa == avx512_common ? nCdhw16c : nCdhw8c;
-
- bool ok = true
- && mayiuse(isa)
- && is_fwd()
- && !has_zero_dim_memory()
- && one_of(ndims(), 4, 5)
- && src_md()->data_type == f32
- && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
- && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
- && (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok) return status::unimplemented;
-
- if (is_training() && fuse_bn_relu()) {
- if (isa < avx2) return status::unimplemented;
- init_default_ws(1);
- }
-
- if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
- && isa < avx2)
- return status::unimplemented;
-
- auto scratchpad = scratchpad_registry().registrar();
- uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
- const pd_t *apd): cpu_primitive_t(apd)
-{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
-
-template <cpu_isa_t isa>
-status_t jit_uni_batch_normalization_fwd_t<isa>::execute(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
-
- auto mean = pd()->stats_is_src()
- ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN))
- : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
- auto var = pd()->stats_is_src()
- ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE))
- : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
-
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- auto scratchpad = this->scratchpad(ctx);
-
- bnorm_driver_->init_barriers(scratchpad);
-
- parallel(0, [&](const int ithr, const int nthr) {
- bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
- scale_shift, nullptr, mean, var, ws, scratchpad);
- });
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
-{ delete bnorm_driver_; }
-
-/* bwd */
-
-template <cpu_isa_t isa>
-status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
- auto desired_fmt_tag = (ndims() == 4)
- ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
- : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;
-
- bool ok = true
- && mayiuse(isa)
- && is_bwd()
- && !has_zero_dim_memory()
- && one_of(ndims(), 4, 5)
- && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type)
- && IMPLICATION(use_scaleshift(),
- utils::everyone_is(f32,
- weights_md()->data_type,
- diff_weights_md()->data_type))
- && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
- && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
- && isa < avx2)
- return status::unimplemented;
-
- if (fuse_bn_relu()) {
- if (isa < avx2) return status::unimplemented;
- init_default_ws(1);
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- /* TODO: extra checks required */
-
- auto scratchpad = scratchpad_registry().registrar();
- uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
- const pd_t *apd): cpu_primitive_t(apd)
-{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
-
-template <cpu_isa_t isa>
-status_t jit_uni_batch_normalization_bwd_t<isa>::execute(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
- auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
- auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
- auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
-
- auto scratchpad = this->scratchpad(ctx);
-
- bnorm_driver_->init_barriers(scratchpad);
-
- parallel(0, [&](const int ithr, const int nthr) {
- bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
- scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
- });
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
-{ delete bnorm_driver_; }
-
-/* struct instantiation */
-template struct jit_uni_batch_normalization_fwd_t<sse42>;
-template struct jit_uni_batch_normalization_bwd_t<sse42>;
-template struct jit_uni_batch_normalization_fwd_t<avx2>;
-template struct jit_uni_batch_normalization_bwd_t<avx2>;
-template struct jit_uni_batch_normalization_fwd_t<avx512_common>;
-template struct jit_uni_batch_normalization_bwd_t<avx512_common>;
-
-}
-}
-}
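The channel block width used throughout this file follows the vector length of the target ISA, with SSE4.2 special-cased to 8 so the blocking still matches the nChw8c layout even though an XMM register holds only 4 floats. A small illustrative mapping, with the usual register widths written out rather than queried from cpu_isa_traits:

    #include <cstddef>

    enum class Isa { sse42, avx2, avx512_common };

    // Illustrative only: vlen_bytes mimics cpu_isa_traits<isa>::vlen.
    static int channel_block_width(Isa isa) {
        const std::size_t vlen_bytes =
                isa == Isa::sse42 ? 16 : isa == Isa::avx2 ? 32 : 64;
        // SSE4.2 keeps an 8-wide block to match nChw8c; the kernels then
        // cover each block with two 4-float passes.
        return isa == Isa::sse42 ? 8 : static_cast<int>(vlen_bytes / sizeof(float));
    }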
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp
deleted file mode 100644
index 96410ec84e..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp
+++ /dev/null
@@ -1,100 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP
-#define JIT_UNI_BATCH_NORMALIZATION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_isa_traits.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace { template <cpu_isa_t isa> struct uni_bnorm_driver_t; }
-
-template <cpu_isa_t isa>
-struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_batch_normalization_fwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_batch_normalization_fwd_t<isa>);
-
- status_t init();
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- jit_uni_batch_normalization_fwd_t(const pd_t *apd);
- ~jit_uni_batch_normalization_fwd_t();
-
- virtual status_t execute(const exec_ctx_t &ctx) const override;
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- uni_bnorm_driver_t<isa> *bnorm_driver_;
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_batch_normalization_bwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_batch_normalization_bwd_t<isa>);
-
- status_t init();
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- jit_uni_batch_normalization_bwd_t(const pd_t *apd);
- ~jit_uni_batch_normalization_bwd_t();
-
- virtual status_t execute(const exec_ctx_t &ctx) const override;
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- uni_bnorm_driver_t<isa> *bnorm_driver_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
deleted file mode 100644
index b7dc5f85c5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
+++ /dev/null
@@ -1,1302 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "cpu_memory.hpp"
-
-#include "jit_uni_dw_conv_kernel_f32.hpp"
-
-#define GET_OFF(field) offsetof(jit_conv_call_s, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::prop_kind;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-using namespace Xbyak;
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int ow = 0; ow < ur_w; ow++) {
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
-
- int b_off = ch*jcp.ch_block + i*4;
- if (this->jcp.with_bias)
- uni_vmovups(vmm_acc,
- vmmword[reg_bias + b_off*sizeof(float)]);
- else
- uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
-
- int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
- + ow*jcp.ch_block + i*4;
- if (this->jcp.with_sum)
- uni_vaddps(vmm_acc, vmm_acc,
- vmmword[reg_output + o_off*sizeof(float)]);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
- int ur_ch_blocks, int ur_w) {
- int ch_blk = jcp.ch_block;
- int dilate_h = jcp.dilate_h + 1;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
-
- Label iter_exit_label;
-
- cmp(reg_kh, 0);
- je(iter_exit_label, T_NEAR);
- cmp(reg_kw, 0);
- je(iter_exit_label, T_NEAR);
-
- mov(iter_kh, reg_kh);
- Label kh_label;
- L(kh_label); {
- mov(iter_kw, reg_kw);
- mov(aux1_reg_input, aux_reg_input);
- mov(aux1_reg_kernel, aux_reg_kernel);
-
- Label kw_label;
- L(kw_label); {
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
- Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
- + ker_off*sizeof(float)]);
-
- for (int ow = 0; ow < ur_w; ow++) {
- int inp_off = ch*jcp.ih*jcp.iw*ch_blk
- + ow*stride_w*ch_blk + i*4;
- Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux1_reg_input
- + inp_off*sizeof(float)]);
-
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
- + ch*ur_w + ow);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
- }
- }
- }
- add(aux1_reg_kernel, ch_blk*sizeof(float));
- add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));
-
- dec(iter_kw);
- cmp(iter_kw, 0);
- jg(kw_label, T_NEAR);
- }
- add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
- add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
-
- dec(iter_kh);
- cmp(iter_kh, 0);
- jg(kh_label, T_NEAR);
- }
-
- L(iter_exit_label);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
- int ur_ch_blocks, int ur_w) {
- int ch_blk = jcp.ch_block;
- int dilate_h = jcp.dilate_h + 1;
- int dilate_w = jcp.dilate_w + 1;
- int stride_w = jcp.stride_w;
-
- Label iter_exit_label;
-
- cmp(reg_kh, 0);
- je(iter_exit_label, T_NEAR);
-
- mov(iter_kh, reg_kh);
- Label kh_label;
- L(kh_label); {
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int kw = 0; kw < jcp.kw; kw++) {
- int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;
-
- Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux_reg_kernel
- + ker_off*sizeof(float)]);
-
- for (int ow = 0; ow < ur_w; ow++) {
- int inp_off = ch*jcp.ih*jcp.iw*ch_blk
- + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;
-
- Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux_reg_input
- + inp_off*sizeof(float)]);
-
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
- + ch*ur_w + ow);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
- }
- }
- }
- }
-
- add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
- add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
-
- dec(iter_kh);
- cmp(iter_kh, 0);
- jg(kh_label, T_NEAR);
- }
-
- L(iter_exit_label);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(
- int ur_ch_blocks, int ur_w) {
- if (this->jcp.with_eltwise) {
- int repeats = isa == sse42 ? 2 : 1;
- eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4);
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
- int ur_ch_blocks, int ur_w) {
- int ch_blk = jcp.ch_block;
-
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int ow = 0; ow < ur_w; ow++) {
- int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
- Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
-
- uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
- Label unrolled_w_label;
- Label tail_w_label;
- Label exit_label;
-
- L(unrolled_w_label); {
- int ur_w = jcp.ur_w;
-
- cmp(reg_ur_w, ur_w);
- jl(tail_w_label, T_NEAR);
-
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
-
- load_src(ur_ch_blocks, ur_w);
- apply_filter_unrolled(ur_ch_blocks, ur_w);
- apply_activation(ur_ch_blocks, ur_w);
- store_dst(ur_ch_blocks, ur_w);
-
- add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
-
- sub(reg_ur_w, ur_w);
- jmp(unrolled_w_label);
- }
-
- L(tail_w_label); {
- int ur_w = 1;
-
- cmp(reg_ur_w, ur_w);
- jl(exit_label, T_NEAR);
-
- mov(aux_reg_input, reg_input);
- mov(aux_reg_kernel, reg_kernel);
-
- load_src(ur_ch_blocks, ur_w);
- apply_filter(ur_ch_blocks, ur_w);
- apply_activation(ur_ch_blocks, ur_w);
- store_dst(ur_ch_blocks, ur_w);
-
- add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
-
- sub(reg_ur_w, ur_w);
- jmp(tail_w_label);
- }
-
- L(exit_label);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
- this->preamble();
-
- mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- if (jcp.with_bias)
- mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
- mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
- mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
- mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
- mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
-
- Label ch_blocks_tail_label;
- Label exit_label;
-
- int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
-
- cmp(reg_ch_blocks, jcp.nb_ch_blocking);
- jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
-
- loop_body(jcp.nb_ch_blocking); // channel main loop
-
- if (ch_blocks_tail) {
- L(ch_blocks_tail_label);
-
- cmp(reg_ch_blocks, ch_blocks_tail);
- jne(exit_label, T_NEAR);
-
- loop_body(ch_blocks_tail); // channel tail loop
- }
-
- L(exit_label);
-
- this->postamble();
-
- if (jcp.with_eltwise)
- eltwise_injector_->prepare_table();
-}
-
-template <cpu_isa_t isa>
-bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
- jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
- const auto &p = attr.post_ops_;
-
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
- auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
-
- switch (p.len_) {
- case 0: return true; // no post_ops
- case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
- case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
- default: return false;
- }
-
- return false;
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
- const primitive_attr_t &attr)
-{
- if (!mayiuse(isa)) return status::unimplemented;
-
- const int simd_w = isa == avx512_common ? 16 : 8;
-
- jcp.prop_kind = cd.prop_kind;
-
- const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
- if (!with_groups) return status::unimplemented;
-
- jcp.ngroups = weights_d.dims()[0];
- jcp.mb = src_d.dims()[0];
-
- jcp.oc = dst_d.dims()[1];
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = src_d.dims()[1];
-
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = dst_d.dims()[2];
- jcp.ow = dst_d.dims()[3];
-
- jcp.kh = weights_d.dims()[3];
- jcp.kw = weights_d.dims()[4];
-
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.b_pad = cd.padding[1][0];
- jcp.r_pad = cd.padding[1][1];
-
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
-
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
-
- if (!post_ops_ok(jcp, attr))
- return status::unimplemented;
-
- const auto &p = attr.post_ops_;
- jcp.with_sum = p.find(primitive_kind::sum) != -1;
- const int eltwise_ind = p.find(primitive_kind::eltwise);
- jcp.with_eltwise = eltwise_ind != -1;
- if (jcp.with_eltwise)
- jcp.eltwise = p.entry_[eltwise_ind].eltwise;
-
- bool ok_to_pad_channels = true
- && jcp.oc == jcp.ngroups
- && jcp.ic == jcp.ngroups
- && one_of(isa, avx512_common, avx2);
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
- }
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.oc == jcp.ngroups
- && jcp.ic == jcp.ngroups
- && jcp.ngroups % simd_w == 0
- && jcp.src_tag == dat_tag
- && jcp.wei_tag == wei_tag
- && jcp.dst_tag == dat_tag
- && jcp.ic <= src_d.padded_dims()[1]
- && jcp.oc <= dst_d.padded_dims()[1]
- && jcp.ngroups <= weights_d.padded_dims()[0];
- if (!args_ok) return status::unimplemented;
-
- jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
-
- jcp.ch_block = simd_w;
- jcp.nb_ch = jcp.oc / jcp.ch_block;
- jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
- if (jcp.nb_ch < jcp.nb_ch_blocking)
- jcp.nb_ch_blocking = jcp.nb_ch;
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
- scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
-}
-
-template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
-template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
-template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
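In the nChw8c/nChw16c layout assumed above, stepping one output column advances by a whole channel block, which is where offsets such as ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4 in load_src()/store_dst() come from. A hedged sketch of that index arithmetic; the oh row offset is folded into the base pointer by the driver, so it does not appear here.

    #include <cstddef>

    // Sketch of the element offset used by the forward kernel for one output row.
    // 'half' selects the second 4-float half of an 8-wide block on SSE4.2.
    static std::size_t blocked_dst_offset(int ch_blk_idx, int ow, int oh_total,
                                          int ow_total, int ch_block, int half = 0) {
        return static_cast<std::size_t>(ch_blk_idx) * oh_total * ow_total * ch_block
                + static_cast<std::size_t>(ow) * ch_block
                + static_cast<std::size_t>(half) * 4;
    }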
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
- int ur_ch_blocks, int ur_str_w) {
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int w = 0; w < ur_str_w; w++) {
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
- + ch*ur_str_w + w);
- uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
- int ur_ch_blocks, int ur_str_w) {
- int kw = jcp.kw;
- int kh = jcp.kh;
- int ow = jcp.ow;
- int oh = jcp.oh;
-
- int ch_blk = jcp.ch_block;
- int stride_h = jcp.stride_h;
- int stride_w = jcp.stride_w;
-
- Label iter_exit_label;
-
- cmp(reg_kh, 0);
- je(iter_exit_label, T_NEAR);
-
- cmp(reg_kw, 0);
- je(iter_exit_label, T_NEAR);
-
- mov(iter_kh, reg_kh);
- Label kh_label;
- L(kh_label); {
- mov(aux1_reg_ddst, aux_reg_ddst);
- mov(aux1_reg_kernel, aux_reg_kernel);
-
- mov(iter_kw, reg_kw);
- Label kw_label;
- L(kw_label); {
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- int ker_off = ch*kh*kw*ch_blk + i*4;
- Vmm vmm_ker = get_ker_reg(0);
- uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
- + ker_off*sizeof(float)]);
-
- for (int w = 0; w < ur_str_w; w++) {
- int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;
-
- Vmm vmm_src = get_src_reg(0);
- uni_vmovups(vmm_src, ptr[aux1_reg_ddst
- + ddst_off*sizeof(float)]);
-
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
- + ch*ur_str_w + w);
- uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
- }
- }
- }
-
- add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
- sub(aux1_reg_ddst, ch_blk*sizeof(float));
-
- sub(iter_kw, stride_w);
- cmp(iter_kw, 0);
- jg(kw_label, T_NEAR);
- }
-
- add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
- sub(aux_reg_ddst, ow*ch_blk*sizeof(float));
-
- sub(iter_kh, stride_h);
- cmp(iter_kh, 0);
- jg(kh_label, T_NEAR);
- }
-
- L(iter_exit_label);
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
- int ur_ch_blocks, int ur_str_w) {
- int ch_blk = jcp.ch_block;
- int iw = jcp.iw;
- int ih = jcp.ih;
- int stride_w = jcp.stride_w;
-
- int repeats = isa == sse42 ? 2 : 1;
- for (int i = 0; i < repeats; i++) {
- for (int ch = 0; ch < ur_ch_blocks; ch++) {
- for (int w = 0; w < ur_str_w; w++) {
- int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
- Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
- + ch*ur_str_w + w);
-
- uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
- int ur_ch_blocks) {
- Label unrolled_w_label;
- Label tail_w_label;
- Label exit_label;
-
- L(unrolled_w_label); {
- int ur_w = jcp.ur_w;
-
- cmp(reg_ur_str_w, ur_w);
- jl(tail_w_label, T_NEAR);
-
- mov(aux_reg_ddst, reg_ddst);
- mov(aux_reg_kernel, reg_kernel);
-
- load_ddst(ur_ch_blocks, ur_w);
- apply_filter(ur_ch_blocks, ur_w);
- store_dsrc(ur_ch_blocks, ur_w);
-
- add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
-
- sub(reg_ur_str_w, ur_w);
- jmp(unrolled_w_label);
- }
-
- L(tail_w_label); {
- int ur_w = 1;
-
- cmp(reg_ur_str_w, ur_w);
- jl(exit_label, T_NEAR);
-
- mov(aux_reg_ddst, reg_ddst);
- mov(aux_reg_kernel, reg_kernel);
-
- load_ddst(ur_ch_blocks, ur_w);
- apply_filter(ur_ch_blocks, ur_w);
- store_dsrc(ur_ch_blocks, ur_w);
-
- add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
- add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
-
- sub(reg_ur_str_w, ur_w);
- jmp(tail_w_label);
- }
-
- L(exit_label);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
- preamble();
-
- mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
- mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
- mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
- mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
- mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
- mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
- mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
-
- Label ch_blocks_tail_label;
- Label exit_label;
-
- int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
-
- cmp(reg_ch_blocks, jcp.nb_ch_blocking);
- jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
-
- loop_body(jcp.nb_ch_blocking); // channel main loop
-
- if (ch_blocks_tail) {
- L(ch_blocks_tail_label);
-
- cmp(reg_ch_blocks, ch_blocks_tail);
- jne(exit_label, T_NEAR);
-
- loop_body(ch_blocks_tail); // channel tail loop
- }
-
- L(exit_label);
-
- this->postamble();
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
- jit_conv_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d) {
- if (!mayiuse(isa)) return status::unimplemented;
-
- const int simd_w = isa == avx512_common ? 16 : 8;
-
- const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
- if (!with_groups) return status::unimplemented;
-
- jcp.ngroups = weights_d.dims()[0];
- jcp.mb = diff_src_d.dims()[0];
-
- jcp.oc = diff_dst_d.dims()[1];
- jcp.oc_without_padding = jcp.oc;
- jcp.ic = diff_src_d.dims()[1];
-
- jcp.ih = diff_src_d.dims()[2];
- jcp.iw = diff_src_d.dims()[3];
- jcp.oh = diff_dst_d.dims()[2];
- jcp.ow = diff_dst_d.dims()[3];
-
- jcp.kh = weights_d.dims()[3];
- jcp.kw = weights_d.dims()[4];
-
- jcp.t_pad = cd.padding[0][0];
- jcp.l_pad = cd.padding[0][1];
- jcp.b_pad = cd.padding[1][0];
- jcp.r_pad = cd.padding[1][1];
-
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
-
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
-
- bool ok_to_pad_channels = true
- && jcp.oc == jcp.ngroups
- && jcp.ic == jcp.ngroups
- && one_of(isa, avx512_common, avx2);
- if (ok_to_pad_channels) {
- jcp.oc = rnd_up(jcp.oc, simd_w);
- jcp.ic = rnd_up(jcp.ic, simd_w);
- jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
- }
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.oc == jcp.ngroups
- && jcp.ic == jcp.ngroups
- && jcp.ngroups % simd_w == 0
- && jcp.dilate_h == 0
- && jcp.dilate_w == 0
- && jcp.src_tag == dat_tag
- && jcp.wei_tag == wei_tag
- && jcp.dst_tag == dat_tag
- && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
- && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
- && jcp.ic <= diff_src_d.padded_dims()[1]
- && jcp.oc <= diff_dst_d.padded_dims()[1]
- && jcp.ngroups <= weights_d.padded_dims()[0];
- if (!args_ok) return status::unimplemented;
-
- jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
-
- jcp.ch_block = simd_w;
- jcp.nb_ch = jcp.ic / jcp.ch_block;
- jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
- if (jcp.nb_ch < jcp.nb_ch_blocking)
- jcp.nb_ch_blocking = jcp.nb_ch;
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- UNUSED(scratchpad);
- UNUSED(jcp);
-}
-
-template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
-template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
-template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
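The backward-data init_conf() above only accepts shapes where the padded input is consumed exactly by the kernel and stride (and it rejects dilation), i.e. oh == (ihp - kh) / stride_h + 1 and the analogous width condition. A minimal helper expressing that check, for illustration only:

    // True when an output extent matches the stride/padding formula used above
    // (dilation is assumed to be zero, as the checks require).
    static bool output_extent_ok(int in, int pad_lo, int pad_hi, int kernel,
                                 int stride, int out) {
        const int padded = in + pad_lo + pad_hi;
        return out == (padded - kernel) / stride + 1;
    }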
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
- for (int r = 0; r < reg_repeats; ++r) {
- for (int i = 0; i < jcp.kw; ++i) {
- Vmm vmm_acc = get_acc_reg(r * jcp.kw + i);
- uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() {
- for (int r = 0; r < reg_repeats; ++r) {
- const int reg_set = r * jcp.kw;
- for (int i = 0; i < jcp.kw; ++i) {
- int off_filter = (reg_set + i) * simd_w;
- Vmm vmm_acc = get_acc_reg(reg_set + i);
- uni_vmovups(vmm_acc,
- vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() {
- for (int r = 0; r < reg_repeats; ++r) {
- Vmm vmm_bias = get_bias_reg(r);
- uni_vpxor(vmm_bias, vmm_bias, vmm_bias);
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() {
- for (int r = 0; r < reg_repeats; ++r) {
- Vmm vmm_bias = get_bias_reg(r);
- uni_vmovups(
- vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]);
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
- int unroll_w, int l_pad, int pad_offset, int ow_block) {
-
- const int iw_block = ow_block * jcp.stride_w;
- const int right_border = jcp.iw - iw_block;
-
- const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
-
- /* preamble count for number of cascaded LOAD + FMA operation */
- const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
-
- /* LOAD initial input registers, then cascade LOADs and FMAs*/
- for (int r = 0; r < reg_repeats; ++r) {
- for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
- int off_output = (i_ur * reg_repeats + r) * simd_w;
- Vmm vmm_output = get_output_reg(r);
- uni_vmovups(vmm_output,
- ptr[reg_tmp_output + off_output * sizeof(float)]);
- if (i_ur == 0) {
- for (int c = 0; c < input_overlap; ++c) {
- int off_input
- = ((c - pad_offset) * reg_repeats + r) * simd_w;
- Vmm vmm_input
- = get_input_reg((c % jcp.kw) * reg_repeats + r);
- uni_vmovups(vmm_input,
- ptr[reg_tmp_input + off_input * sizeof(float)]);
- }
- } else {
- for (int c = 0; c < cascade_input; ++c) {
- int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
- int off_input
- = ((overlap + c - pad_offset) * reg_repeats + r)
- * simd_w;
- Vmm vmm_input = get_input_reg(
- ((overlap + c) % jcp.kw) * reg_repeats + r);
- uni_vmovups(vmm_input,
- ptr[reg_tmp_input + off_input * sizeof(float)]);
- }
- }
-
- for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
- int io_overlap = i_kw + (i_ur * jcp.stride_w);
-
- /* Don't apply FMAs that fall into the padded region */
- if (io_overlap - l_pad < 0
- || io_overlap - jcp.l_pad >= right_border)
- continue;
-
- Vmm vmm_input = get_input_reg(
- ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
- Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
- Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
- if (isa == sse42)
- uni_vmovups(vmm_aux, vmm_input);
- uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void
-jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
- const int unroll_w) {
- for (int r = 0; r < reg_repeats; ++r) {
- for (int i = 0; i < unroll_w; ++i) {
- Vmm vmm_bias = get_bias_reg(r);
- int off_output = (i * reg_repeats + r) * simd_w;
- if (isa == sse42) {
- /* Need to support unaligned address loads for SSE42*/
- Vmm vmm_output = get_output_reg(1 + r);
- uni_vmovups(vmm_output,
- ptr[reg_tmp_output + off_output * sizeof(float)]);
- uni_vaddps(vmm_bias, vmm_bias, vmm_output);
- } else {
- uni_vaddps(vmm_bias, vmm_bias,
- vmmword[reg_tmp_output + off_output * sizeof(float)]);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() {
- for (int r = 0; r < reg_repeats; ++r) {
- const int reg_set = r * jcp.kw;
- for (int i = 0; i < jcp.kw; ++i) {
- int off_filter = (i + reg_set) * simd_w;
- Vmm vmm_acc = get_acc_reg(i + reg_set);
- uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
- vmm_acc);
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() {
- for (int r = 0; r < reg_repeats; ++r) {
- Vmm vmm_bias = get_bias_reg(r);
- uni_vmovups(
- vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias);
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
- const int block_size) {
- Label oh_label;
- Label ow_blk_label;
-
- const int unroll_w = nstl::min(block_size, jcp.ow);
- const int unroll_w_trips = jcp.ow / unroll_w;
- const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;
-
- const int ch_offset = jcp.ch_block;
-
- mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
- mov(reg_oh_worksize,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
-
- mov(reg_tmp_output, reg_output_baddr);
- L(oh_label);
- {
-
- mov(iter_ow_blk, unroll_w_trips);
- L(ow_blk_label);
- {
-
- compute_bias_step_unroll(unroll_w);
- add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));
-
- dec(iter_ow_blk);
- cmp(iter_ow_blk, 0);
- jg(ow_blk_label, T_NEAR);
- }
-
- if (tail_w > 0) {
- compute_bias_step_unroll(tail_w);
- add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
- }
-
- inc(reg_oh);
- cmp(reg_oh, reg_oh_worksize);
- jl(oh_label, T_NEAR);
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {
-
- const int ch_offset = jcp.ch_block;
-
- Label kh_loop_label, skip_zeroing_label;
-
- mov(reg_exec_flags,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
- and_(reg_exec_flags, FLAG_ZERO_FILTER);
- test(reg_exec_flags, reg_exec_flags);
- je(skip_zeroing_label);
-
- zero_filter();
-
- mov(reg_tmp_filter, reg_filter_baddr);
- mov(reg_kh, jcp.kh);
- L(kh_loop_label);
- {
- store_filter();
-
- add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
- dec(reg_kh);
- cmp(reg_kh, 0);
- jg(kh_loop_label);
- }
-
- /* Comeback pointers */
- sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));
-
- L(skip_zeroing_label);
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
- int unroll_w, int l_pad, int pad_offset, int ow_block) {
-
- const int ch_offset = jcp.ch_block;
-
- Label kh_loop_label, skip_loop_label;
-
- cmp(reg_kh_count, 0);
- je(skip_loop_label, T_NEAR);
-
- mov(reg_kh, reg_kh_count);
- L(kh_loop_label);
- {
- load_filter();
- compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
- store_filter();
-
- add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
- add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
- dec(reg_kh);
- cmp(reg_kh, 0);
- jg(kh_loop_label);
- }
-
- /* Comeback pointers */
- Label kh_comeback_label;
- mov(reg_kh, reg_kh_count);
- L(kh_comeback_label);
- {
- sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
- sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
- dec(reg_kh);
- cmp(reg_kh, 0);
- jg(kh_comeback_label, T_NEAR);
- }
-
- L(skip_loop_label);
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
- int unroll_w, int l_pad, int pad_offset, int ow_block) {
-
- const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
- jcp.ih / jcp.stride_h - 1 :
- jcp.oh - jcp.b_pad - 1;
- const int ch_offset = jcp.ch_block;
- const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
- const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
-
- Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
- end_h_loop_label;
-
- mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
- mov(reg_oh_worksize,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
- mov(reg_kh_count,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);
-
- mov(reg_tmp_output, reg_output_baddr);
- mov(reg_tmp_input, reg_input_baddr);
- mov(reg_tmp_filter, reg_filter_baddr);
-
- L(h_loop_label);
- {
-
- compute_h_step(unroll_w, l_pad, pad_offset, ow_block);
-
- add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));
-
- /* If within the top_pad region */
- if (jcp.t_pad > 0) {
- /* Skip t_pad area if no longer in initial h_block */
- cmp(reg_oh, jcp.t_pad);
- jg(skip_tpad_label, T_NEAR);
-
- cmp(reg_kh_count, jcp.kh);
- jge(skip_tpad_label, T_NEAR);
-
- add(reg_kh_count, t_overlap_off);
- sub(reg_tmp_filter,
- t_overlap_off * jcp.kw * ch_offset * sizeof(float));
-
- /* kernel has moved beyond padding (adjust for stride effects) */
- if (jcp.t_pad % jcp.stride_h != 0) {
- int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
- add(reg_tmp_input,
- inp_corr * jcp.iw * ch_offset * sizeof(float));
- }
- jmp(tpad_loop_label, T_NEAR);
- }
-
- L(skip_tpad_label);
-
- cmp(reg_oh, io_overlap);
- jl(skip_bpad_label, T_NEAR);
- sub(reg_kh_count, b_overlap_off);
-
- L(skip_bpad_label);
- add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
-
- L(tpad_loop_label);
-
- cmp(reg_oh, jcp.ih / jcp.stride_h);
- jge(end_h_loop_label, T_NEAR);
-
- inc(reg_oh);
-
- cmp(reg_oh, reg_oh_worksize);
- jl(h_loop_label, T_NEAR);
- }
- L(end_h_loop_label);
-}
-
-template <cpu_isa_t isa>
-inline void
-jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
-
- const int ch_offset = jcp.ch_block;
- int ow = jcp.ow;
- int pad_offset = 0;
- int l_pad = jcp.l_pad;
-
- /* Calculate effective padding */
- int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
- + (jcp.kw - 1) * (jcp.dilate_w + 1)
- - (jcp.iw + jcp.l_pad - 1));
-
- /* Is this strictly defined by:
- * -code-size (?)
- * -address size (?) */
- const int max_unroll_w = 30;
- const int block_size = 15;
-
- int unroll_w_tail = 0;
- int unroll_w = 0;
- int unroll_w_trips = 0;
-
- if (jcp.ow > max_unroll_w) {
- unroll_w = nstl::min(block_size, jcp.ow);
- unroll_w_trips = ow / unroll_w;
- /* calculate tail */
- unroll_w_tail = ow % unroll_w;
- /* Perform some rebalancing if the tail is too small */
- if ((unroll_w_tail == 0 && r_pad != 0)
- || (r_pad > 0 && r_pad >= unroll_w_tail)) {
- if (unroll_w_trips > 1) {
- unroll_w_tail += unroll_w;
- unroll_w_trips--;
- } else {
- /* Ideally, this case shouldn't happen */
- unroll_w_tail += (unroll_w - unroll_w / 2);
- unroll_w = unroll_w / 2;
- }
- }
- } else {
- unroll_w = jcp.ow;
- unroll_w_trips = nstl::max(1, ow / unroll_w);
- }
- if (jcp.with_bias) {
- Label skip_load_bias;
- mov(reg_bias_baddr,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
-
- zero_bias();
-
- mov(reg_exec_flags,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
- and_(reg_exec_flags, FLAG_ZERO_BIAS);
- test(reg_exec_flags, reg_exec_flags);
- jne(skip_load_bias);
-
- load_bias();
-
- L(skip_load_bias);
- compute_bias_loop(block_size);
-
- store_bias();
- }
-
- /* Pass filter address, then offset for h_padding. */
- compute_zero_filter();
- mov(reg_kh_offset,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
- add(reg_filter_baddr, reg_kh_offset);
-
- /* compute left padded block */
- if (l_pad) {
- compute_h_loop(unroll_w, l_pad, 0, 0);
- add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
- add(reg_input_baddr,
- unroll_w * jcp.stride_w * ch_offset * sizeof(float));
- unroll_w_trips--;
- pad_offset = l_pad;
- l_pad = 0;
- }
-
- /* compute middle block */
- Label ow_blk_label;
-
- /* Insert loop for 'ow' block when middle block needs to execute more
- * than once */
- bool do_ow_blk_loop = unroll_w_trips > 1;
- if (do_ow_blk_loop) {
- mov(iter_ow_blk, unroll_w_trips);
- L(ow_blk_label);
- }
- if (unroll_w_trips > 0) {
- compute_h_loop(unroll_w, l_pad, pad_offset, 0);
- add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
- add(reg_input_baddr,
- unroll_w * jcp.stride_w * ch_offset * sizeof(float));
- }
- if (do_ow_blk_loop) {
- dec(iter_ow_blk);
- cmp(iter_ow_blk, 0);
- jg(ow_blk_label, T_NEAR);
- }
-
- /* compute right padded block */
- if (unroll_w_tail) {
- compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
- }
-}
-
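compute_ow_block_unroll() above splits the output row into a left padded block, a repeated middle block, and a right padded tail, rebalancing when the tail would fall entirely inside the right padding. A host-side sketch of that planning step, reusing the same constants (30 and 15) with illustrative names:

    #include <algorithm>

    struct UnrollPlan { int unroll_w; int trips; int tail; };

    // Sketch of the unroll_w / tail rebalancing; ow is the output width,
    // r_pad the effective right padding computed above.
    static UnrollPlan plan_ow_unroll(int ow, int r_pad) {
        const int max_unroll_w = 30, block_size = 15;
        UnrollPlan p{ow, 1, 0};
        if (ow > max_unroll_w) {
            p.unroll_w = std::min(block_size, ow);
            p.trips = ow / p.unroll_w;
            p.tail = ow % p.unroll_w;
            // Rebalance if the tail is empty or lands inside the right padding.
            if ((p.tail == 0 && r_pad != 0) || (r_pad > 0 && r_pad >= p.tail)) {
                if (p.trips > 1) {
                    p.tail += p.unroll_w;
                    --p.trips;
                } else {
                    p.tail += p.unroll_w - p.unroll_w / 2;
                    p.unroll_w /= 2;
                }
            }
        }
        return p;
    }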
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
- preamble();
-
- mov(reg_input_baddr,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]);
- mov(reg_output_baddr,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
- mov(reg_filter_baddr,
- ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);
-
- compute_ow_block_unroll();
-
- this->postamble();
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
- jit_conv_conf_t &jcp, const convolution_desc_t &cd,
- const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_weights_d,
- const memory_desc_wrapper &diff_dst_d, int nthreads) {
- if (!mayiuse(isa))
- return status::unimplemented;
-
- jcp.ngroups = diff_weights_d.dims()[0];
- jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
- jcp.ic = src_d.dims()[1] / jcp.ngroups;
-
- const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
-
- jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);
-
- if (!jcp.is_depthwise)
- return status::unimplemented;
-
- jcp.ch_block = isa == avx512_common ? 16 : 8;
-
- jcp.mb = src_d.dims()[0];
-
- jcp.ih = src_d.dims()[2];
- jcp.iw = src_d.dims()[3];
- jcp.oh = diff_dst_d.dims()[2];
- jcp.ow = diff_dst_d.dims()[3];
-
- jcp.kh = diff_weights_d.dims()[3];
- jcp.kw = diff_weights_d.dims()[4];
-
- jcp.stride_h = cd.strides[0];
- jcp.stride_w = cd.strides[1];
-
- jcp.t_pad = cd.padding[0][0];
- jcp.b_pad = cd.padding[1][0];
-
- jcp.l_pad = cd.padding[0][1];
- jcp.r_pad = cd.padding[1][1];
-
- jcp.dilate_h = cd.dilates[0];
- jcp.dilate_w = cd.dilates[1];
-
- jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
- jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
-
- jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
- jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
- jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
-
- bool args_ok = true
- && jcp.src_tag == dat_tag
- && jcp.wei_tag == wei_tag
- && jcp.dst_tag == dat_tag
- && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
- && jcp.dilate_w == 0 && jcp.kw <= 3
- && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
- && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
- if (!args_ok)
- return status::unimplemented;
-
- jcp.nb_ch = jcp.ngroups / jcp.ch_block;
-
- /* kernel applicability check wrt boundaries
- * the conditions are quite general across the kernels we have,
- * but ideally the check should belong to a specific kernel... */
- const int max_hpad = (jcp.kh - 1 + 1) / 2;
- const int max_wpad = (jcp.kw - 1 + 1) / 2;
- const bool boundaries_ok = true && jcp.t_pad <= max_hpad
- && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
- && jcp.r_pad <= max_wpad;
- if (!boundaries_ok)
- return status::unimplemented;
-
- balance(jcp, nthreads);
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
- memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
- /* Notes: if splitting thread work on 'mb', then a reduction has to take
- * place. Hence, book a per-thread, local weights-buffer for the
- * reduction */
- if (jcp.nthr_mb > 1) {
- const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
- scratchpad.book(key_conv_wei_reduction,
- sizeof(float) * wei_size * (jcp.nthr_mb - 1));
-
- if (jcp.with_bias)
- scratchpad.book(key_conv_bia_reduction,
- sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
- int nthreads) {
- jcp.nthr = nthreads;
- jcp.nthr_g = jcp.nthr_mb = 1;
-
- /* Basic-Heuristics for parallel strategy:
- * 1) Tries to parallel on the number of Groups (g) where tasks are
- * independent. Otherwise,
- * 2) Tries to split the work across g and MiniBatch (mb).
- * Parallelizing on mb requires computing a reduction for weights.
- *
- * NOTE: because of 'task partitioning' scheme, there will be unbalanced
- * per-thread load when the number of threads is high (e.g. > 16).
- */
- jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
- jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
-
- jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
-}
-
-template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;
-template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
-template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>;
-
-}
-}
-}
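The balance() heuristic at the end of this file first spends threads on channel groups (independent work) and only then splits the remainder over the minibatch, which is what forces the per-thread reduction buffers booked in init_scratchpad(). A minimal sketch of that split; it assumes nb_ch >= 1, as init_conf() guarantees.

    #include <algorithm>

    // Sketch of the group/minibatch thread partitioning in balance().
    static void split_threads(int nthr, int nb_ch, int mb,
                              int &nthr_g, int &nthr_mb) {
        nthr_g = std::min(nb_ch, nthr);                      // groups first
        nthr_mb = std::min(std::max(1, nthr / nthr_g), mb);  // then minibatch
        // Total threads actually used: nthr_g * nthr_mb.
    }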
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp
deleted file mode 100644
index 9c08fc4a09..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp
+++ /dev/null
@@ -1,253 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP
-#define JIT_UNI_DW_CONV_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-#include "jit_uni_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
-
- jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp)
- : jcp(ajcp), eltwise_injector_(nullptr)
- {
- if (jcp.with_eltwise)
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
- jcp.eltwise);
-
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- ~jit_uni_dw_conv_fwd_kernel_f32() {
- delete eltwise_injector_;
- }
-
- static bool post_ops_ok(jit_conv_conf_t &jcp,
- const primitive_attr_t &attr);
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
- using reg64_t = const Xbyak::Reg64;
- const Xbyak::AddressFrame &vmmword = (isa == sse42)
- ? xword : (isa == avx2) ? yword : zword;
- const int vlen = cpu_isa_traits<isa>::vlen;
-
- // dw convolution
- reg64_t reg_input = r8;
- reg64_t aux_reg_input = r9;
- reg64_t aux1_reg_input = r10;
- reg64_t reg_kernel = r11;
- reg64_t aux_reg_kernel = r12;
- reg64_t aux1_reg_kernel = r13;
- reg64_t reg_output = r14;
- reg64_t reg_bias = r15;
- reg64_t reg_kh = rax;
- reg64_t reg_kw = rbx;
- reg64_t iter_kh = rdx;
- reg64_t iter_kw = rsi;
- reg64_t reg_ur_w = rbp;
- reg64_t reg_ch_blocks = aux1_reg_input;
- reg64_t imm_addr64 = aux1_reg_input;
-
- inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
- inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
- inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
-
- inline void load_src(int ur_ch_blocks, int ur_w);
- inline void apply_filter(int ur_ch_blocks, int ur_w);
- inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w);
- inline void apply_activation(int ur_ch_blocks, int ur_w);
- inline void store_dst(int ur_ch_blocks, int ur_w);
- inline void loop_body(int ur_ch_blocks);
-
- jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
-
- void generate();
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)
-
- jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) {
- this->generate();
- jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
- }
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd,
- const memory_desc_wrapper &diff_src_d,
- const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &diff_dst_d);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_conv_call_s *);
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
- using reg64_t = const Xbyak::Reg64;
-
- inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
- inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
- inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
-
- reg64_t reg_ddst = rax;
- reg64_t aux_reg_ddst = r8;
- reg64_t aux1_reg_ddst = abi_not_param1;
- reg64_t reg_kernel = rdx;
- reg64_t aux_reg_kernel = r10;
- reg64_t aux1_reg_kernel = rbp;
- reg64_t reg_dsrc = rsi;
-
- reg64_t reg_ur_str_w = r9;
- reg64_t reg_ch_blocks = rbx;
-
- reg64_t iter_kh = r11;
- reg64_t iter_kw = r12;
- reg64_t reg_kh = r13;
- reg64_t reg_kw = r14;
-
- inline void loop_body(int ur_ch_blocks);
- inline void load_ddst(int ur_ch_blocks, int ur_str_w);
- inline void apply_filter(int ur_ch_blocks, int ur_str_w);
- inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
-
- void generate();
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)
-
- jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {
- this->generate();
- jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode();
- }
-
- static status_t init_conf(jit_conv_conf_t &jcp,
- const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
- const memory_desc_wrapper &diff_weights_d,
- const memory_desc_wrapper &diff_dst_d, int nthreads);
-
- static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
- const jit_conv_conf_t &jcp);
-
- static void balance(jit_conv_conf_t &jcp, int nthreads);
-
- jit_conv_conf_t jcp;
- void (*jit_ker)(jit_dw_conv_call_s *);
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
- using reg64_t = const Xbyak::Reg64;
- const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
- const int reg_repeats = (isa == sse42) ? 2 : 1;
-
- const Xbyak::AddressFrame &vmmword
- = (isa == sse42) ? xword : (isa == avx2) ? yword : zword;
-
- /* XXX: offset between input and accumulators is 3, therefore, assume 'kw'
- * is no larger than 3 */
- inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
- inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); }
- inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); }
- inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); }
- inline Vmm get_aux_reg() { return Vmm(0); }
-
- reg64_t reg_tmp_input = r9;
- reg64_t reg_tmp_output = r10;
- reg64_t reg_tmp_filter = r13;
- reg64_t reg_kh_offset = rax;
-
- /* parameter passed by driver into kernel */
- Xbyak::Reg8 reg_exec_flags = bl;
-
- reg64_t reg_oh_worksize = r14;
- reg64_t reg_oh = rax;
-
- reg64_t iter_ow_blk = r11;
-
- reg64_t reg_kh = rsi;
- reg64_t reg_kh_count = rdx;
-
- /* Base addresses for convolution parameters. */
- reg64_t reg_input_baddr = r15;
- reg64_t reg_output_baddr = r12;
- reg64_t reg_filter_baddr = abi_not_param1;
- reg64_t reg_bias_baddr = r13;
-
- /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
- */
- inline void compute_ow_step_unroll(
- int unroll_w, int l_pad, int pad_offset, int ow_block);
-
- /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
- inline void compute_h_step(
- int unroll_w, int l_pad, int pad_offset, int ow_block);
- inline void compute_h_loop(
- int unroll_w, int l_pad, int pad_offset, int ow_block);
-
- /* Write 'width' micro-kernel JITs; depending on the padding and convolution
- * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
- * right ow-block.*/
- inline void compute_ow_block_unroll();
-
- inline void compute_zero_filter();
- inline void load_filter();
- inline void zero_filter();
- inline void load_bias();
- inline void zero_bias();
- inline void compute_bias_step_unroll(const int unroll_w);
- inline void compute_bias_loop(const int block_size);
- inline void store_filter();
- inline void store_bias();
-
- void generate();
-};
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp
deleted file mode 100644
index 58449601a3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp
+++ /dev/null
@@ -1,427 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "jit_uni_dw_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace mkldnn::impl::utils;
-
-template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const auto &jcp = kernel_->jcp;
-
- if (pd()->wants_padded_bias()) {
- auto padded_bias = this->scratchpad(ctx).template get<data_t>(
- key_conv_padded_bias);
- utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
- utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
- jcp.oc - jcp.oc_without_padding);
- bias = padded_bias;
- }
-
- int dil_h = jcp.dilate_h + 1;
- int dil_w = jcp.dilate_w + 1;
- int str_h = jcp.stride_h;
- int str_w = jcp.stride_w;
-
- auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
- int kh_padding, int ch, int ch_num, int n) {
- auto par_conv = jit_conv_call_s();
-
- const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
- const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
- + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
-
- const int iw = nstl::max((ow*str_w - jcp.l_pad
- + div_up(i_l_overflow, dil_w)*dil_w), 0);
- const int kw = div_up(i_l_overflow, dil_w);
-
- const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
- - div_up(i_r_overflow, dil_w);
-
- par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)];
- par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)];
-
- par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
- if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)];
-
- par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
- par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
-
- par_conv.ur_w = (size_t)ur_w_step;
-
- par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
-
- return par_conv;
- };
-
- const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
- parallel_nd(jcp.mb, chb_work, jcp.oh,
- [&](int n, int chb, int oh) {
- int ch = chb * jcp.nb_ch_blocking;
- int ch_num = jcp.nb_ch_blocking;
-
- const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
- const int i_b_overflow = nstl::max(jcp.ih,
- (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
-
- const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
- + div_up(i_t_overflow, dil_h)*dil_h), 0);
- const int kh = div_up(i_t_overflow, dil_h);
- const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
- - div_up(i_b_overflow, dil_h);
-
- // left border
- int ow = 0;
- int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
- int ur_w_step = 1;
- for (; ow < l_border; ow++) {
- jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
- kh, kh_padding, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
- }
-
- // main loop
- ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
- / jcp.stride_w - ow + 1;
- if (ur_w_step > 0) {
- jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
- kh, kh_padding, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
-
- ow += ur_w_step;
- }
-
- // right border
- ur_w_step = 1;
- for (; ow < jcp.ow; ow++) {
- jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
- kh, kh_padding, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
- }
- });
-
- if (pd()->wants_zero_pad_dst())
- ctx.memory(MKLDNN_ARG_DST)->zero_pad();
-}
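As a worked example of the padding arithmetic in kernel_params above (purely illustrative numbers, not taken from any real configuration): take jcp.kh = 3, jcp.t_pad = 1, stride_h = 1, dil_h = 1 and output row oh = 0. Then i_t_overflow = max(0, 1 - 0) = 1 and, for an input tall enough that the bottom border is not reached, i_b_overflow = 0; the call therefore starts at filter row kh = div_up(1, 1) = 1 and input row ih = max(0 - 1 + 1, 0) = 0, and processes kh_padding = 3 - 1 - 0 = 2 filter rows, so the padded row above the image is skipped rather than read. The same arithmetic is applied horizontally, and each output row is split into a left border, one unrolled interior call, and a right border so that only the border columns pay for the per-column overflow handling.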
-
-template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
-template struct _jit_uni_dw_convolution_fwd_t<avx2>;
-template struct _jit_uni_dw_convolution_fwd_t<sse42>;
-
-template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
-
- const auto &jcp = kernel_->jcp;
-
- auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
- int i_t_overflow, int i_b_overflow, int stride_off_h,
- int ch, int ch_num, int n) {
- auto par_conv = jit_conv_call_s();
-
- const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad));
- const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw)
- - jcp.r_pad));
-
- int ow = iw + jcp.l_pad - i_r_overflow;
- int stride_off_w = ow % jcp.stride_w;
- ow /= jcp.stride_w;
-
- par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)];
- par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)];
- par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow
- + stride_off_h, i_r_overflow + stride_off_w)];
-
- par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow
- - stride_off_h);
- par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow
- - stride_off_w);
-
- par_conv.ur_str_w = ur_str_w;
-
- par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
-
- return par_conv;
- };
-
- const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
- parallel_nd(jcp.mb, chb_work, jcp.ih,
- [&](int n, int chb, int ih) {
- int ch = chb * jcp.nb_ch_blocking;
- int ch_num = jcp.nb_ch_blocking;
-
- const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih
- - jcp.t_pad));
- const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1
- - (jcp.ih - 1 - ih) - jcp.b_pad));
-
- int oh = ih + jcp.t_pad - i_b_overflow;
- int stride_off_h = oh % jcp.stride_h;
- oh /= jcp.stride_h;
-
- for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) {
- // left border
- int iw = i_str_w;
- int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw);
- int ur_str_w = 1;
- for (; iw < l_border; iw += jcp.stride_w) {
- jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
- ih, i_t_overflow, i_b_overflow,
- stride_off_h, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
- }
-
- // main loop
- ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw)
- / jcp.stride_w, jcp.iw);
- if (ur_str_w > 0) {
- jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
- ih, i_t_overflow, i_b_overflow,
- stride_off_h, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
-
- iw += ur_str_w * jcp.stride_w;
- }
-
- // right border
- ur_str_w = 1;
- for (; iw < jcp.iw; iw += jcp.stride_w) {
- jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
- ih, i_t_overflow, i_b_overflow,
- stride_off_h, ch, ch_num, n);
-
- kernel_->jit_ker(&par_conv);
- }
- }
- });
-}
-
-template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
-template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
-template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;
-
-template <cpu_isa_t isa>
-_jit_uni_dw_convolution_bwd_weights_t<isa>::
-_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , kernel_(nullptr), acc_ker_(nullptr)
-{
- kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
- if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
- acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
-}
-
-template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- auto diff_wei_reduction_buf =
- scratchpad(ctx).template get<data_t>(key_conv_wei_reduction);
- auto diff_bia_reduction_buf =
- scratchpad(ctx).template get<data_t>(key_conv_bia_reduction);
-
- const auto &jcp = kernel_->jcp;
-
- /* Used when executing a parallel reduction */
- simple_barrier::ctx_t reduction_bctx;
- simple_barrier::ctx_init(&reduction_bctx);
-
- const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
- const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
-
- const int ch_block = jcp.ch_block;
-
- auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
- const int batch, const int group, const int oh_start,
- const int work_size, const unsigned char exec_flag,
- const size_t kh_padding, const size_t filter_off) {
- const int tpad_underflow_off = jcp.t_pad - filter_off;
-
- conv_params->exec_flags = exec_flag;
- conv_params->kh_count = jcp.kh - kh_padding;
-
- const int oh_s = oh_start;
- const int oh_e = oh_start + work_size;
- const int ih_s = oh_s * jcp.stride_h;
-
- conv_params->filter_pad_off
- = filter_off * jcp.kw * ch_block * sizeof(float);
- conv_params->oh_index = oh_s;
- conv_params->oh_count = oh_e;
-
- size_t diff_dst_off
- = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
- + oh_start)
- * jcp.ow;
-
- size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
- + ih_s - tpad_underflow_off) * jcp.iw;
-
- conv_params->output = &diff_dst[diff_dst_off * ch_block];
- conv_params->input = &src[src_off * ch_block];
- };
-
- parallel(jcp.nthr, [&](const int ithr, const int nthr) {
- assert(nthr == jcp.nthr);
-
- auto conv_params = jit_dw_conv_call_s();
- const int h_block_size = 15;
-
- /* assign iteration space to thread */
- const int ithr_g = ithr % jcp.nthr_g;
- const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
-
- /* split dimensions */
- int g_start{ 0 }, g_end{ 0 };
- balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);
-
- int mb_start{ 0 }, mb_end{ 0 };
- balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);
-
- auto diff_wei = ithr_mb == 0
- ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
- auto diff_bia = ithr_mb == 0
- ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;
-
- for (int g = g_start; g < g_end; ++g) {
- unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
- unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
-
- size_t diff_wei_off = g * jcp.kh * jcp.kw;
- conv_params.filter = &diff_wei[diff_wei_off * ch_block];
-
- if (jcp.with_bias)
- conv_params.bias = &diff_bia[g * ch_block];
-
- for (int mb = mb_start; mb < mb_end; ++mb) {
- int oh = 0;
- while (oh < jcp.oh) {
- const int h_work = nstl::min(h_block_size, jcp.oh - oh);
- auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
- auto kh_b_padding
- = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
- jcp.b_pad - (h_work - 1) :
- 0;
-
- set_kernel_params(&conv_params, mb, g, oh, h_work,
- zero_filter_flag | zero_bias_flag,
- kh_t_padding + kh_b_padding, kh_t_padding);
- kernel_->jit_ker(&conv_params);
-
- zero_bias_flag &= ~FLAG_ZERO_BIAS;
- zero_filter_flag &= ~FLAG_ZERO_FILTER;
- oh += h_work;
- }
- }
- }
-
- if (do_parallel_reduction() && jcp.nthr_mb > 1) {
- size_t reduct_start{ 0 }, reduct_end{ 0 };
- balance211(wei_size, nthr, ithr, reduct_start, reduct_end);
-
- const int acc_size = reduct_end - reduct_start;
- const size_t reduct_off = reduct_start;
- auto *acc_data = diff_weights + reduct_off;
-
- simple_barrier::barrier(&reduction_bctx, nthr);
-
- for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
- auto *src_data = diff_wei_reduction_buf
- + (thr_mb - 1) * wei_size + reduct_off;
- acc_ker_->accumulate(acc_data, src_data, acc_size);
- }
- }
- });
-
- if (jcp.nthr_mb <= 1) return;
-
- /* Apply single-threaded 'mb' reduction */
- for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
- size_t mb_accum_offset = (thr_mb - 1) * wei_size;
- size_t b_accum_offset = (thr_mb - 1) * bias_size;
-
- for (int g = 0; g < jcp.nb_ch; ++g) {
- /* Reduction on Bias */
- if (jcp.with_bias) {
- PRAGMA_OMP_SIMD()
- for (int g_block = 0; g_block < ch_block; ++g_block) {
- size_t bias_offset = g * ch_block + g_block;
- diff_bias[bias_offset] += diff_bia_reduction_buf[
- b_accum_offset + bias_offset];
- }
- }
-
- if (do_parallel_reduction()) continue;
-
- for (int kh = 0; kh < jcp.kh; ++kh)
- for (int kw = 0; kw < jcp.kw; ++kw)
- {
- size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
- PRAGMA_OMP_SIMD()
- for (int g_block = 0; g_block < ch_block; ++g_block) {
- const size_t off = wei_offset * ch_block + g_block;
- diff_weights[off] +=
- diff_wei_reduction_buf[mb_accum_offset + off];
- }
- }
- }
- }
-}
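The reduction scheme above can be summarized as follows: the thread slice with ithr_mb == 0 accumulates directly into diff_weights / diff_bias, every other mb slice writes its partial sums into a private region of the scratchpad reduction buffers, and those partials are folded back in, either cooperatively (the acc_ker_ path behind the barrier) or by the single-threaded loop at the end. A minimal sketch of that fold, assuming wei_size counts individual float elements as it does above:

    // conceptual equivalent of the 'mb' reduction over per-thread partial sums
    for (int thr_mb = 1; thr_mb < nthr_mb; ++thr_mb)
        for (size_t i = 0; i < wei_size; ++i)
            diff_weights[i] += diff_wei_reduction_buf[(thr_mb - 1) * wei_size + i];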
-
-template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
-template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
-template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp
deleted file mode 100644
index ca53749ec2..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp
+++ /dev/null
@@ -1,266 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP
-#define CPU_JIT_UNI_DW_CONVOLUTION_HPP
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-
-#include "cpu_barrier.hpp"
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-#include "cpu_reducer.hpp"
-
-#include "jit_uni_dw_conv_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- pd_t(engine_t *engine, const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const typename pd_t::base_class *hint_fwd_pd)
- : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
- _jit_uni_dw_convolution_fwd_t<isa>);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- status_t status = jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(
- jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(scratchpad,
- jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32<isa>(pd()->jcp_); }
-
- ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_dw_conv_fwd_kernel_f32<isa> *kernel_;
-};
-
-using jit_avx512_common_dw_convolution_fwd_t =
- _jit_uni_dw_convolution_fwd_t<avx512_common>;
-using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<avx2>;
-using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<sse42>;
-
-template <cpu_isa_t isa>
-struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_()
- {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
- _jit_uni_dw_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::undef, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
-
- if (!ok) return status::unimplemented;
-
- status_t status = jit_uni_dw_conv_bwd_data_kernel_f32<isa>::
- init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(),
- *diff_dst_md());
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
- scratchpad, jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32<isa>(pd()->jcp_); }
- ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_dw_conv_bwd_data_kernel_f32<isa> *kernel_;
-};
-
-using jit_avx512_common_dw_convolution_bwd_data_t =
- _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
-using jit_avx2_dw_convolution_bwd_data_t =
- _jit_uni_dw_convolution_bwd_data_t<avx2>;
-using jit_sse42_dw_convolution_bwd_data_t =
- _jit_uni_dw_convolution_bwd_data_t<sse42>;
-
-template <cpu_isa_t isa>
-struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- pd_t(engine_t *engine,
- const convolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const convolution_fwd_pd_t *hint_fwd_pd)
- : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , jcp_() {}
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
- _jit_uni_dw_convolution_bwd_weights_t<isa>);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(data_type::f32, data_type::f32,
- data_type::f32, data_type::f32, data_type::f32)
- && !has_zero_dim_memory()
- && set_default_formats();
- if (!ok) return status::unimplemented;
-
- const int max_threads = mkldnn_in_parallel()
- ? 1 : mkldnn_get_max_threads();
-
- status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::
- init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(),
- *diff_dst_md(), max_threads);
- if (status != status::success) return status;
-
- auto scratchpad = scratchpad_registry().registrar();
- jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
- scratchpad, jcp_);
-
- return status::success;
- }
-
- jit_conv_conf_t jcp_;
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
-
- auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
- auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
-
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd);
- ~_jit_uni_dw_convolution_bwd_weights_t() {
- delete kernel_;
- delete acc_ker_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- bool do_parallel_reduction() const { return false; }
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_dw_conv_bwd_weights_kernel_f32<isa> *kernel_;
- cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
-};
-
-using jit_avx512_common_dw_convolution_bwd_weights_t =
- _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
-using jit_avx2_dw_convolution_bwd_weights_t =
- _jit_uni_dw_convolution_bwd_weights_t<avx2>;
-using jit_sse42_dw_convolution_bwd_weights_t =
- _jit_uni_dw_convolution_bwd_weights_t<sse42>;
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp
deleted file mode 100644
index 2af6435871..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp
+++ /dev/null
@@ -1,1142 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "jit_uni_eltwise.hpp"
-
-#define GET_OFF(field) offsetof(jit_args, field)
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
- size_t end_idx) {
- preserved_vecs_count = 0;
- vecs_to_preserve = (size_t)aux_vecs_count(alg_);
- start_idx_tail = start_idx;
-
-    // For sse42 the mask register has to be Xmm(0)
- if (isa == sse42 && vecs_to_preserve > 0) {
- size_t idx = 0;
- assert(idx < start_idx);
- preserved_vec_idxs[preserved_vecs_count++] = idx;
- }
-
- for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
- if (preserved_vecs_count >= vecs_to_preserve) break;
- if (start_idx <= idx && idx < end_idx) continue;
-
- preserved_vec_idxs[preserved_vecs_count++] = idx;
- }
-
- size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
- for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
- preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
- }
-
- assert(preserved_vecs_count == vecs_to_preserve);
-
- if (save_state_) {
- h->push(p_table);
-
- if (preserved_vecs_count)
- h->sub(h->rsp, preserved_vecs_count * vlen);
-
- for (size_t i = 0; i < preserved_vecs_count; ++i)
- h->uni_vmovups(h->ptr[h->rsp + i * vlen],
- Vmm(preserved_vec_idxs[i]));
-
- load_table_addr();
- }
-
- assign_regs();
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
-{
- size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
- if (tail_vecs_to_preserve == 0) return;
-
- const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
-
- if (save_state_) {
- if (idx_off)
- h->add(h->rsp, idx_off * vlen);
-
- for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
- h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
- h->ptr[h->rsp + i * vlen]);
- }
-
- for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
- preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
-
- if (save_state_) {
- for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
- h->uni_vmovups(h->ptr[h->rsp + i * vlen],
- Vmm(preserved_vec_idxs[idx_off + i]));
-
- if (idx_off)
- h->sub(h->rsp, idx_off * vlen);
- }
-
- assign_regs();
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
- if (!save_state_) return;
-
- for (size_t i = 0; i < preserved_vecs_count; ++i)
- h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
- h->ptr[h->rsp + i * vlen]);
-
- if (preserved_vecs_count)
- h->add(h->rsp, preserved_vecs_count * vlen);
-
- h->pop(p_table);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
- vmm_mask = Vmm(preserved_vec_idxs[0]);
- vmm_aux0 = Vmm(preserved_vec_idxs[0]);
- vmm_aux1 = Vmm(preserved_vec_idxs[1]);
- vmm_aux2 = Vmm(preserved_vec_idxs[2]);
- vmm_aux3 = Vmm(preserved_vec_idxs[3]);
- vmm_aux4 = Vmm(preserved_vec_idxs[4]);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
- h->uni_vminps(vmm_src, vmm_src, table_val(10));
- h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
- h->uni_vmovups(vmm_aux0, vmm_src);
- //calculate exp(x)
- // fx = x * log2ef + 0.5
- h->uni_vmulps(vmm_src, vmm_src, table_val(2));
- h->uni_vaddps(vmm_src, vmm_src, table_val(1));
-
- // tmp = floorf(fx)
- if (isa == avx512_common) {
- h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
- h->vcvtdq2ps(vmm_aux1, vmm_aux1);
-
- h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
- h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
-
- h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
- } else {
- h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
- }
-
- //keep fx for further computations
- h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
-
- //x = x - fx * ln2
- h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
-
- // compute 2^n
- h->uni_vcvtps2dq(vmm_aux1, vmm_src);
- h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
-    h->uni_vpslld(vmm_aux1, vmm_aux1, 23); // vmm_aux1 = 2^fx
-
- // y = p5
- h->uni_vmovups(vmm_src, table_val(9));
- // y = y * x + p4
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
- // y = y * x + p3
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
- // y = y * x + p2
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
- // y = y * x + p1
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
- // y = y * x + p0
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q)
- // y = y * 2^n
- h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
-}
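For reference, the scalar computation this vector code implements is the standard range reduction exp(x) = 2^n * exp(r), with n = floor(x * log2(e) + 0.5) and r = x - n * ln(2), followed by a degree-5 polynomial in r and a reconstruction of 2^n through the float exponent field. A minimal standalone sketch (the decimal constants mirror the table entries listed further below; this is an illustration, not the kernel itself):

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    static float exp_approx(float x) {
        const float log2ef = 1.44269502f, ln2f = 0.69314718f;
        x = std::fmax(std::fmin(x, 88.3762589f), -14.5f);      // clamp, as the minps/maxps above do
        const float fx = std::floor(x * log2ef + 0.5f);        // n
        const float r  = x - fx * ln2f;                        // reduced argument
        float y = 0.008369149f;                                // p5
        y = y * r + 0.041917507f;                              // + p4
        y = y * r + 0.16666505f;                               // + p3
        y = y * r + 0.4999887f;                                // + p2
        y = y * r + 1.0f;                                      // + p1
        y = y * r + 1.0000001f;                                // + p0
        int32_t bits = (static_cast<int32_t>(fx) + 127) << 23; // 2^n via the exponent field
        float scale;
        std::memcpy(&scale, &bits, sizeof scale);
        return y * scale;
    }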
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
-{
- const int alpha_off = 0, zero_off = 1;
-
- h->uni_vmovups(vmm_aux1, vmm_src);
- if (isa == sse42) {
- h->movups(vmm_mask, vmm_src);
- h->mulps(vmm_src, table_val(alpha_off));
- h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
- h->blendvps(vmm_src, vmm_aux1);
- } else if (isa == avx2) {
- h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
- h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
- h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
- } else if (isa == avx512_common) {
- h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
- h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
- h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
- const Vmm &vmm_src) {
- const int zero_off = 1;
- h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
- const int alpha_off = 23, zero_off = 24;
-
- // compute exponent
- h->uni_vmovups(vmm_aux2, vmm_src);
- exp_compute_vector(vmm_src);
-
- // alpha * (exp(x) - 1)
- h->uni_vsubps(vmm_src, vmm_src, table_val(0));
- h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
-
- // combine with mask
- if (isa == sse42) {
- h->pxor(vmm_mask, vmm_mask);
- h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os);
- h->blendvps(vmm_src, vmm_aux2);
- } else if (isa == avx2) {
- h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
- h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
- } else if (isa == avx512_common) {
- h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
- h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
- }
-}
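relu_compute_vector and elu_compute_vector above are vectorized selects; their scalar meaning, shown here only as a reference (the real code expresses the ternary with a compare plus blend per ISA):

    #include <cmath>

    static float leaky_relu_ref(float x, float alpha) {
        return x > 0.f ? x : alpha * x;             // relu_compute_vector
    }

    static float elu_ref(float x, float alpha) {
        return x > 0.f ? x : alpha * std::expm1(x); // elu_compute_vector: alpha * (exp(x) - 1)
    }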
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
-{
- // # comes from Taylor expansion error bound
- // > linear_sat_point = single(sqrt(3) * 1b-12);
- // # comes from the exp formula cancellation
- // > exp_bound_point = (single(log(3)/2));
- // # comes from rounding accuracy in float
- // > one_sat_point = round(atanh(1 - 1b-25), single, RU);
- // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
- // [linear_sat_point, exp_bound_point], relative, floating);
- // > err_bound = D(sup(supnorm(P, tanh(x),
- // [linear_sat_point, exp_bound_point], relative, theta)));
- // 0x1.fffd6f00b9539p-25
- // > P;
- // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
- // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
- // + x^0x1p1 * 0x1.09fa1p-6))))
-
- // register mapping
- // vmm_src contains input
-    // vmm_aux0 contains the mask of currently valid results:
-    // 1 means the element still needs computation, 0 means it is already computed
- // vmm_aux1 contains current output
- // vmm_aux2, vmm_aux3 contains auxiliary values
- // vmm_aux4 contains the original sign of inputs
-
- Label end_tanh_label;
-
- auto test_exit =[&](Xbyak::Address threshold){
-        // the copy is not necessary for >AVX, but should not matter for performance
- h->uni_vmovups(vmm_aux0, vmm_src);
- if (isa == avx512_common){
- h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
- h->kortestw(k_mask, k_mask);
- } else {
- h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
- h->uni_vtestps(vmm_aux0, vmm_aux0);
- }
- h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
- };
-
- auto blend_results=[&](Vmm vmm_partial_res){
- if (isa == avx512_common)
- h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
- else
- h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
- };
-
-    // because tanh(x) = -tanh(-x), we extract the sign to make x positive
-    // and reapply the sign at the end
-    // the mov is not necessary for >AVX, but should not matter for performance
- h->uni_vmovups(vmm_aux4, vmm_src);
- h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
- h->uni_vandps(vmm_src, vmm_src, table_val(17));
-
- // if x < linear_sat_point for all inputs, we just return the input
- h->uni_vmovups(vmm_aux1, vmm_src);
- test_exit(table_val(13));
-
-    // if any mask element is set, we have to compute a better approximation
- h->uni_vmovups(vmm_aux2, vmm_src);
- h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
- h->uni_vmovups(vmm_aux3, table_val(22));
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
- h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
-
-    // we blend only the results that need updating
- blend_results(vmm_aux3);
-
- // if x < exp_bound_point, we go to return point
- test_exit(table_val(14));
-
-    // otherwise we use the better approximation 1 - 2 / (1 + exp(2x))
- // compute 2x
- h->uni_vmovups(vmm_aux3, vmm_src);
- h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
-
- // Compute exp(2x)
-    // We need to save k_mask, vmm_aux0, vmm_aux1 and vmm_src, as exp may clobber
-    // them; they are restored right after the call since they are still needed
-    // for the remaining blends and the final saturation step
- auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
- h->sub(h->rsp, stack_size);
- h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
- h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
- h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
- if (isa == avx512_common)
- h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
-
- exp_compute_vector(vmm_aux3);
-
- h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
- h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
- h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
- if (isa == avx512_common)
- h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
- h->add(h->rsp, stack_size);
-
- // 1 + exp(2x)
- h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
-
- // 1 - 2 / (1 + exp(2x))
- h->uni_vmovups(vmm_aux2, table_val(16));
- h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
- h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
-
-    // we blend only the results that need updating
- blend_results(vmm_aux2);
-
- // finally, we saturate to 1 if needed
- // TODO: maybe move that up if most inputs saturate in practice
- if (isa == avx512_common)
- h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
- else {
- h->uni_vmovups(vmm_aux0, vmm_src);
- h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
- }
- h->uni_vmovups(vmm_aux2, table_val(0));
- blend_results(vmm_aux2);
-
- h->L(end_tanh_label);
- {
- // we apply the sign of x to the result and we are done
- h->uni_vmovups(vmm_src, vmm_aux1);
- h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
- }
-}
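A scalar outline of the piecewise strategy above; the thresholds and coefficients are approximate decimal decodings of the hex table constants ([13], [14], [15], [18]..[22] in elu_prepare_table below), and the branches are only for readability, since the kernel keeps everything branch-free with masks and blends:

    #include <cmath>

    static float tanh_approx(float x) {
        const float sign = std::copysign(1.0f, x);
        const float ax = std::fabs(x);
        float r;
        if (ax < 4.23e-4f) {                 // below the linear saturation point: tanh(x) ~ x
            r = ax;
        } else if (ax < 0.5493f) {           // odd polynomial in x^2
            const float x2 = ax * ax;
            float p = 1.6234e-2f;
            p = p * x2 - 5.2533e-2f;
            p = p * x2 + 1.3318e-1f;
            p = p * x2 - 3.3333e-1f;
            p = p * x2 + 9.9999994e-1f;
            r = p * ax;
        } else if (ax < 9.01f) {             // tanh(x) = 1 - 2 / (1 + exp(2x))
            r = 1.0f - 2.0f / (1.0f + std::exp(2.0f * ax));
        } else {
            r = 1.0f;                        // saturate to 1
        }
        return sign * r;
    }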
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
- const Vmm &vmm_src) {
- h->uni_vmulps(vmm_src, vmm_src, vmm_src);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
-    // compute abs(x) by clearing the sign bit: x & 0x7fffffff
- h->uni_vandps(vmm_src, vmm_src, table_val(0));
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
-{
- if (isa == avx512_common) {
- h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
- h->uni_vsqrtps(vmm_aux1, vmm_src);
- h->uni_vmovups(vmm_src, table_val(0));
- h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
- } else {
- h->uni_vmovups(vmm_mask, vmm_src);
- h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
- h->uni_vsqrtps(vmm_aux1, vmm_src);
- h->uni_vmovups(vmm_src, table_val(0));
- h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
- const Vmm &vmm_src) {
- // compute x = alpha * x + beta;
- h->uni_vmovups(vmm_aux0, table_val(0));
- h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
- const Vmm &vmm_src) {
-    // compute bounded relu: clamp x to [0, alpha]
- h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
- h->uni_vminps(vmm_src, vmm_src, table_val(0));
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
- const Vmm &vmm_src) {
- // duplicate src
- h->uni_vmovups(vmm_aux2, vmm_src);
-
- h->uni_vminps(vmm_src, vmm_src, table_val(24));
- h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
- h->uni_vmovups(vmm_aux1, vmm_src);
- // calculate exp(x)
- // fx = x * log2ef + 0.5
- h->uni_vmulps(vmm_src, vmm_src, table_val(2));
- h->uni_vaddps(vmm_src, vmm_src, table_val(1));
-
- // tmp = floorf(fx)
- if (isa == avx512_common) {
- h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
- h->vcvtdq2ps(vmm_aux0, vmm_aux0);
-
- h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
- h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
-
- h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
- } else {
- h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
- }
-
- // keep fx for further computations
- h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
- // calculation fx * ln2
- h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
- // x = x - fx * ln2
- h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
- // y = p5
- h->uni_vmovups(vmm_aux3, table_val(22));
- // y = y * x + p4
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
- // y = y * x + p3
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
- // y = y * x + p2
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
- // y = y * x + p1
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
- // y = y * x + p0
- h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
-
- // compute 2^(-n)
- if (isa == avx512_common) {
- h->vmulps(vmm_aux1, vmm_src, table_val(23));
- h->vcvtps2dq(vmm_aux1, vmm_aux1);
- } else {
- h->uni_vcvtps2dq(vmm_aux1, vmm_src);
- h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
- }
-
- h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
- h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
- // calculate ln(1 + y)
- h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
- // x = y; y is free; keep x for further computations
- h->uni_vmovups(vmm_src, vmm_aux3);
- // frexp()
- h->uni_vpsrld(vmm_src, vmm_src, 23);
- h->uni_vcvtdq2ps(vmm_src, vmm_src);
-    // got n, where x = 2^n * y and y is in [0.5, 1)
- h->uni_vsubps(vmm_src, vmm_src, table_val(5));
-
- h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
-    // got y (the mantissa), 0.5 <= y < 1
- h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
- // y = y - 1
- h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
- // y = p8
- h->uni_vmovups(vmm_aux1, table_val(16));
- // y = y * x + p7
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
- // y = y * x + p6
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
- // y = y * x + p5
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
- // y = y * x + p4
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
- // y = y * x + p3
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
- // y = y * x + p2
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
- // y = y * x + p1
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
- // y = y * x + p0 ; p0 = 0
- h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
- //calculate ln(2) * n
- h->uni_vmulps(vmm_src, vmm_src, table_val(3));
- h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
- h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
-
- // get vmm_mask = src > max logf
- h->uni_vmovups(vmm_mask, vmm_aux2);
- if (isa == avx512_common) {
- // y = (x < max log f) ? soft_relu(x) : x
- h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
- h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
- } else {
- // y = (x < max log f) ? soft_relu(x) : x
- h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
- h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
- }
-
- h->uni_vmovups(vmm_src, vmm_aux1);
-}
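The mathematical function computed above is soft_relu(x) = ln(1 + exp(x)). The kernel never calls libm: it reuses the exp range reduction, evaluates ln(1 + y) by splitting the sum into 2^n * m with m in [0.5, 1) through exponent-field bit manipulation and a degree-8 polynomial in (m - 1), and finally passes x through unchanged once x exceeds max logf, where ln(1 + exp(x)) equals x to within float precision. A scalar reference, for checking only:

    #include <cmath>

    static float soft_relu_ref(float x) {
        if (x > 88.3762589f) return x;       // large-x passthrough, as the final blend above does
        return std::log1p(std::exp(x));      // ln(1 + exp(x))
    }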
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
- const Vmm &vmm_src) {
- // we store the original sign and make x negative
-    // IMPORTANT: we assume vmm_aux0 to be xmm0, as it is required for the sse4.2 path
- // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
- h->uni_vmovups(vmm_aux2, vmm_src);
- h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
- h->uni_vorps(vmm_src, vmm_src, table_val(12));
-
- exp_compute_vector(vmm_src);
- // dup exp(x)
- h->uni_vmovups(vmm_aux1, vmm_src);
- // (exp(x) + 1)
- h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
- // y = exp(x) / (exp(x) + 1)
- h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
-
- // Now we have to apply the "symmetry" based on original sign
- h->uni_vmovups(vmm_aux3, table_val(0));
- h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
- if (isa == avx512_common) {
- h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
- h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
- } else {
- h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
- h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
- }
- h->uni_vmovups(vmm_src, vmm_aux3);
-}
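The scalar equivalent of the symmetry trick above: the argument is forced negative so that exp never overflows, and the result for originally positive inputs is recovered through sigmoid(x) = 1 - sigmoid(-x). A reference sketch:

    #include <cmath>

    static float logistic_ref(float x) {
        const float e = std::exp(-std::fabs(x));  // exp of a non-positive argument: no overflow
        const float s = e / (1.0f + e);           // sigmoid(-|x|)
        return x > 0.f ? 1.0f - s : s;            // reapply the symmetry for positive inputs
    }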
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
- const unsigned int cvals[] = {
- 0x3f800000, // [0] 1.0f
- 0x3f000000, // [1] 0.5f
- 0x3fb8aa3b, // [2] log2ef = 1.44269502f
- 0x3f317218, // [3] ln2f = 0.69314718f
- 0x0000007f, // [4] 0x7f
-            // exp(x) polynomial
- 0x3f800001, // [5] p0 = 1.0000001f
- 0x3efffe85, // [6] p2 = 0.4999887f
- 0x3e2aaa3e, // [7] p3 = 0.16666505f
- 0x3d2bb1b1, // [8] p4 = 0.041917507f
- 0x3c091ec1, // [9] p5 = 0.008369149f
- 0x42b0c0a5, //[10] max logf = 88.3762589f
- 0xc1766666, //[11] min logf = -14.5f
- // tanh(x) constants,
- 0x80000000, //[12] mask to extract sign
- 0x39ddb3d7, //[13] arg below which tanh(x) = x
- 0x3f0c9f54, //[14] arg below which pol approx is valid
- 0x41102cb4, //[15] arg after which tanh(x) = 1
- 0xc0000000, //[16] -2.0f
- 0x7fffffff, //[17] mask to make positive
- // tanh pol approx
- 0x3f7fffff, //[18] p0
- 0xbeaaa9cf, //[19] p1
- 0x3e085f1f, //[20] p2
- 0xbd572bda, //[21] p3
- 0x3c84fd08, //[22] p4
- };
-
- for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
- }
-
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
- const unsigned int cvals[] = {
- 0x3f800000, // [0] 1.0f
- 0x3f000000, // [1] 0.5f
- 0x3fb8aa3b, // [2] log2ef = 1.44269502f
- 0x3f317218, // [3] ln2f = 0.69314718f
- 0x0000007f, // [4] 0x7f
- 0x42fc0000, // [5] 126
- 0x807fffff, // [6] and with (to get 0.5 * mantissa)
- 0x3f000000, // [7] or with (to get 0.5 * mantissa)
- // ln(1 + x) polynomial
- 0xb2b4637d, // [8] p0 = 0.0000000244f
- 0x3f7fff8e, // [9] p1 = 0.9999976971f
- 0xbf001759, //[10] p2 = -0.5002478215f
- 0x3ea70608, //[11] p3 = 0.3272714505f
- 0xbea3d7bf, //[12] p4 = -0.3153830071f
- 0xbe361d04, //[13] p5 = -0.1701777461f
- 0xbfa8f1e6, //[14] p6 = -1.3254635147f
- 0xbfe1e812, //[15] p7 = -1.7971917960f
- 0xbfc4d30e, //[16] p8 = -1.5652673123f
- // exp(x) polynomial
- 0x3f800001, //[17] p0 = 1.0000001f
- 0x3f800000, //[18] p1 = 1.0f
- 0x3efffe85, //[19] p2 = 0.4999887f
- 0x3e2aaa3e, //[20] p3 = 0.16666505f
- 0x3d2bb1b1, //[21] p4 = 0.041917507f
- 0x3c091ec1, //[22] p5 = 0.008369149f
- 0xbf800000, //[23] is required for sign changing
- 0x42b0c0a5, //[24] max logf = 88.3762589f
- 0xc1766666 //[25] min logf = -14.5f
- };
-
- for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) {
- h->dd(cvals[i]);
- }
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
- for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
-}
-
-template <cpu_isa_t isa>
-int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
- switch (alg_) {
- case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
- case alg_kind::eltwise_elu: return 4;
- case alg_kind::eltwise_tanh: return 5;
- case alg_kind::eltwise_square: return 0;
- case alg_kind::eltwise_abs: return 0;
- case alg_kind::eltwise_sqrt: return 2;
- case alg_kind::eltwise_linear: return 1;
- case alg_kind::eltwise_bounded_relu: return 0;
- case alg_kind::eltwise_soft_relu: return 4;
- case alg_kind::eltwise_logistic: return 4;
- default: assert(!"unsupported eltwise algorithm");
- }
-
- return 0;
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
- size_t end_idx) {
- using namespace alg_kind;
- for (size_t idx = start_idx; idx < end_idx; idx++) {
- switch (alg_) {
- case eltwise_relu:
- if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
- else relu_compute_vector(Vmm(idx));
- break;
- case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
- case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
- case eltwise_square: square_compute_vector(Vmm(idx)); break;
- case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
- case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
- case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
- case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
- case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
- case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
- default: assert(!"unsupported eltwise algorithm");
- }
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
- size_t end_idx) {
- assert(start_idx < end_idx && end_idx <= vecs_count);
-
- injector_preamble(start_idx, end_idx);
- compute_body(start_idx_tail, end_idx);
- injector_preamble_tail(start_idx);
- compute_body(start_idx, start_idx_tail);
- injector_postamble();
-}
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
- using namespace alg_kind;
-
- h->align(64);
- h->L(l_table);
-
- if (gen_table) {
- switch (alg_) {
- case eltwise_relu: relu_prepare_table(); break;
- case eltwise_elu:
- case eltwise_tanh:
- case eltwise_logistic:
- elu_prepare_table(); break;
- case eltwise_soft_relu: soft_relu_prepare_table(); break;
- case eltwise_abs: abs_prepare_table(); break;
- case eltwise_sqrt: sqrt_prepare_table(); break;
- case eltwise_linear: linear_prepare_table(); break;
- case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
- case eltwise_square: break;
- default: assert(!"unsupported eltwise algorithm");
- }
- }
-}
-
-template struct jit_uni_eltwise_injector_f32<avx512_common>;
-template struct jit_uni_eltwise_injector_f32<avx2>;
-template struct jit_uni_eltwise_injector_f32<sse42>;
-
-
-struct jit_args {
- const float *from;
- const float *for_comparison;
- const float *to;
- size_t work_amount;
-};
-
-struct jit_uni_eltwise_kernel_f32 : public c_compatible {
- const eltwise_desc_t &desc_;
-
- void (*ker_)(const jit_args *);
- void operator()(const jit_args *args) { assert(ker_); ker_(args); }
-
- jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
- : desc_(desc), ker_(nullptr) {}
- virtual ~jit_uni_eltwise_kernel_f32() {}
-
-protected:
- bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
-};
-
-/* jit kernels */
-namespace {
-
-template <cpu_isa_t isa>
-struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
- public jit_generator
-{
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
-
- void compute_step(bool vectorize, const int uf, const int shift) {
- for (int i = 0; i < uf; i++) {
- if (vectorize) {
- uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
- if (is_bwd())
- uni_vmovups(Vmm(uf + i + 1),
- ptr[reg_for_comparison + i * shift]);
- } else {
- movss(Xmm(i + 1), ptr[reg_from + i * shift]);
- if (is_bwd())
- movss(Xmm(uf + i + 1),
- ptr[reg_for_comparison + i * shift]);
- }
- }
-
- if (isa == sse42) {
- for (int i = 0; i < uf; i++) {
- movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
- mulps(Vmm(2 * uf + i + 1), vmm_ns);
-
- Vmm mask = Vmm(0);
- if (is_bwd()) {
- movups(mask, Vmm(uf + i + 1));
- cmpps(mask, vmm_zero, _cmp_nle_us);
- } else {
- movups(mask, Vmm(i + 1));
- cmpps(mask, vmm_zero, _cmp_nle_us);
- }
- blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
- }
- } else {
- for (int i = 0; i < uf; i++) {
- vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
- if (isa == avx2) {
- if (is_bwd())
- vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
- else
- vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
-
- vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
- Vmm(i + 1), vmm_mask);
-
- } else {
- if (is_bwd())
- vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
- else
- vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
- vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
- Vmm(i + 1));
- }
- }
- }
-
- for (int i = 0; i < uf; i++) {
- if (vectorize) {
- uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
- } else {
- movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
- }
- }
- }
-
- jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
- : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
- assert(desc.alg_kind == alg_kind::eltwise_relu);
- assert(isa == sse42 || isa == avx2 || isa == avx512_common);
-
- Reg64 param = abi_param1;
-
- const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
- const int loop_dec[] = {simd_w, 1};
- const int uf[] = {1, 1};
- const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
- const bool loop_vectorize[] = {true, false};
-
- this->preamble();
-
- mov(reg_from, ptr[param + GET_OFF(from)]);
- if (is_bwd())
- mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
- mov(reg_to, ptr[param + GET_OFF(to)]);
- mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
-
- mov(imm_addr64, float2int(desc.alpha));
- movq(xmm_ns, imm_addr64);
- uni_vbroadcastss(vmm_ns, xmm_ns);
-
- uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
-
- Label loop_label[3];
-
- for (int id = 0; id < 2; id++) {
- L(loop_label[id]);
- cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
- jle(loop_label[id + 1], T_NEAR);
-
- compute_step(loop_vectorize[id], uf[id], shift[id]);
-
- add(reg_from, uf[id] * shift[id]);
- add(reg_to, uf[id] * shift[id]);
- if (is_bwd())
- add(reg_for_comparison, uf[id] * shift[id]);
-
- sub(reg_work_amount, uf[id] * loop_dec[id]);
- jmp(loop_label[id]);
- }
-
- L(loop_label[2]);
- this->postamble();
-
- ker_ = (decltype(ker_))this->getCode();
- }
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xmm,
- isa == avx2, Ymm, Zmm>::type;
-
- Reg64 reg_from = rax;
- Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
- Reg64 reg_to = r8;
- Reg64 reg_work_amount = rsi;
- Reg64 imm_addr64 = rbx;
-
- Xmm xmm_ns = Xmm(14);
-
- Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
- Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
-
- Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
- Opmask k_mask = Opmask(1);
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
- public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
-
- jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
- : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
-
- eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
- desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
-
- using namespace alg_kind;
-
- assert(is_bwd() == false);
- assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
-
- preamble();
-
- Reg64 param = abi_param1;
- mov(reg_from, ptr[param + GET_OFF(from)]);
- mov(reg_to, ptr[param + GET_OFF(to)]);
- mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
- eltwise_injector_->load_table_addr();
-
- Label reminder_loop_start, reminder_loop_end;
- Label vectorized_loop_start, vectorized_loop_end;
-
- cmp(reg_work_amount, simd_w);
- jl(reminder_loop_start, T_NEAR);
-
- L(vectorized_loop_start);
-
- uni_vmovups(vmm_src, ptr[reg_from]);
- eltwise_injector_->compute_vector(vmm_src.getIdx());
- uni_vmovups(ptr[reg_to], vmm_src);
-
- add(reg_from, vlen);
- add(reg_to, vlen);
-
- sub(reg_work_amount, simd_w);
- cmp(reg_work_amount, simd_w);
- jge(vectorized_loop_start, T_NEAR);
-
- L(vectorized_loop_end);
-
- L(reminder_loop_start);
-
- cmp(reg_work_amount, 0);
- jle(reminder_loop_end, T_NEAR);
-
- movss(xmm_src, ptr[reg_from]);
- eltwise_injector_->compute_vector(xmm_src.getIdx());
- movss(ptr[reg_to], xmm_src);
-
- add(reg_from, sizeof(float));
- add(reg_to, sizeof(float));
-
- dec(reg_work_amount);
- jmp(reminder_loop_start, T_NEAR);
-
- L(reminder_loop_end);
-
- postamble();
-
- eltwise_injector_->prepare_table();
-
- ker_ = (decltype(ker_))this->getCode();
- }
-
- ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xmm,
- isa == avx2, Ymm, Zmm>::type;
-
- const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
- const int vlen = cpu_isa_traits<isa>::vlen;
-
- Reg64 reg_from = rax;
- Reg64 reg_to = r8;
- Reg64 reg_work_amount = rsi;
- Reg64 imm_addr64 = rbx;
-
- Xmm xmm_src = Xmm(1);
- Vmm vmm_src = Vmm(1);
-
- jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
-};
-
-} /* namespace */
-
-template <cpu_isa_t isa>
-status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
- using namespace alg_kind;
-
- bool ok = true
- && mayiuse(isa)
- && is_fwd()
- && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
- && !has_zero_dim_memory()
- && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
- eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
- eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
- eltwise_logistic)
- && memory_desc_wrapper(src_md()).is_dense(true)
- && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false),
- math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
- && attr()->has_default_values();
-
- return ok ? status::success : status::unimplemented;
-}
-
-template <cpu_isa_t isa>
-jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd), kernel_(nullptr) {
- const auto &desc = *pd()->desc();
- switch (desc.alg_kind) {
- case alg_kind::eltwise_relu:
- kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
- default:
- kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
- }
-}
-
-template <cpu_isa_t isa>
-jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
-{ delete kernel_; }
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- const size_t nelems = data_d.nelems(true);
-
- src += data_d.offset0();
- dst += data_d.offset0();
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
-
- const int cache_line = 16;
-
- balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
- start = nstl::min(nelems, start * cache_line);
- end = nstl::min(nelems, end * cache_line);
-
- auto arg = jit_args();
- arg.from = &src[start];
- arg.for_comparison = &src[start];
- arg.to = &dst[start];
- arg.work_amount = end - start;
- if (arg.work_amount)
- (*kernel_)(&arg);
- });
-}
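The parallel split above works on whole cache lines: the element count is divided into chunks of 16 floats (one 64-byte line), balance211 hands each thread a contiguous run of chunks, and the chunk bounds are clamped back to nelems so the last thread does not run past the end. A standalone sketch of the same arithmetic, under the assumption that balance211 uses the usual even split with the remainder going to the first threads:

    #include <algorithm>
    #include <cstddef>

    static void thread_range(size_t nelems, int nthr, int ithr,
                             size_t &start, size_t &end) {
        const size_t cache_line = 16;                          // floats per 64-byte line
        const size_t chunks = (nelems + cache_line - 1) / cache_line;
        const size_t per_thr = chunks / nthr, rem = chunks % nthr;
        const size_t c_start = ithr * per_thr + std::min<size_t>(ithr, rem);
        const size_t c_end = c_start + per_thr + (size_t(ithr) < rem ? 1 : 0);
        start = std::min(nelems, c_start * cache_line);
        end = std::min(nelems, c_end * cache_line);
    }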
-
-template <cpu_isa_t isa>
-status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
- bool ok = true
- && !is_fwd()
- && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
- && src_md()->data_type == data_type::f32
- && !has_zero_dim_memory()
- && mayiuse(isa)
- && memory_desc_wrapper(src_md()).is_dense()
- && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md())
- && attr()->has_default_values();
-
- return ok ? status::success : status::unimplemented;
-}
-
-template <cpu_isa_t isa>
-jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd)
- : cpu_primitive_t(apd), kernel_(nullptr) {
- const auto &desc = *pd()->desc();
- switch (desc.alg_kind) {
- case alg_kind::eltwise_relu:
- kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
- default: assert(!"unknown eltwise alg_kind");
- }
-}
-
-template <cpu_isa_t isa>
-jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
-{ delete kernel_; }
-
-template <cpu_isa_t isa>
-void jit_uni_eltwise_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
-
- const size_t nelems = data_d.nelems();
-
- src += data_d.offset0();
- diff_dst += diff_data_d.offset0();
- diff_src += diff_data_d.offset0();
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
-
- const int cache_line = 16;
-
- balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
- start = nstl::min(nelems, start * cache_line);
- end = nstl::min(nelems, end * cache_line);
-
- auto arg = jit_args();
- arg.from = &diff_dst[start];
- arg.to = &diff_src[start];
- arg.for_comparison = &src[start];
- arg.work_amount = end - start;
- if (arg.work_amount)
- (*kernel_)(&arg);
- });
-}
-
-template struct jit_uni_eltwise_fwd_t<sse42>;
-template struct jit_uni_eltwise_bwd_t<sse42>;
-template struct jit_uni_eltwise_fwd_t<avx2>;
-template struct jit_uni_eltwise_bwd_t<avx2>;
-template struct jit_uni_eltwise_fwd_t<avx512_common>;
-template struct jit_uni_eltwise_bwd_t<avx512_common>;
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp
deleted file mode 100644
index 45436b9f46..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp
+++ /dev/null
@@ -1,193 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_ELTWISE_HPP
-#define CPU_JIT_UNI_ELTWISE_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_eltwise_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-struct jit_uni_eltwise_injector_f32 {
- using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
- isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
-
- jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
- float alpha, float beta, bool save_state = true,
- Xbyak::Reg64 p_table = Xbyak::util::rax,
- Xbyak::Opmask k_mask = Xbyak::Opmask(1))
- : alg_(alg), alpha_(alpha), beta_(beta), h(host)
- , save_state_(save_state), p_table(p_table), k_mask(k_mask)
- {
- using namespace alg_kind;
- assert(utils::one_of(isa, sse42, avx2, avx512_common));
- assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
- }
-
- // note that eltwise.scale is ignored
- jit_uni_eltwise_injector_f32(jit_generator *host,
- const post_ops_t::entry_t::eltwise_t &eltwise,
- bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
- Xbyak::Opmask k_mask = Xbyak::Opmask(1))
- : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
- eltwise.beta, save_state, p_table, k_mask) {}
-
- void compute_vector_range(size_t start_idx, size_t end_idx);
- void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
- void prepare_table(bool gen_table=true);
- void load_table_addr() { h->mov(p_table, l_table); }
-
- const alg_kind_t alg_;
- const float alpha_;
- const float beta_;
-
- jit_generator * const h;
-
- const bool save_state_;
- const Xbyak::Reg64 p_table;
- const Xbyak::Opmask k_mask;
- Xbyak::Label l_table;
-
-private:
- // if only the injector was inherited from jit_generator...
- enum {
- _cmp_le_os = jit_generator::_cmp_le_os,
- _cmp_nle_us = jit_generator::_cmp_nle_us,
- _op_floor = jit_generator::_op_floor,
- };
-
- size_t vlen = cpu_isa_traits<isa>::vlen;
-
- const static size_t preserved_vecs_max = 5;
-
- size_t vecs_to_preserve = 0;
- size_t vecs_count = isa == avx512_common ? 32 : 16;
- size_t preserved_vecs_count = 0;
- size_t preserved_vec_idxs[preserved_vecs_max] = {0};
- size_t start_idx_tail = 0;
-
- Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
-
- Xbyak::Address table_val(int index)
- { return h->ptr[p_table + index * vlen]; }
-
- int aux_vecs_count(alg_kind_t alg);
-
- void compute_body(size_t start_idx, size_t end_idx);
- void injector_preamble(size_t start_idx, size_t end_idx);
- void injector_preamble_tail(size_t start_idx);
- void injector_postamble();
- void assign_regs();
-
- void exp_compute_vector(const Vmm &vmm_src);
- void relu_compute_vector(const Vmm &vmm_src);
- void relu_zero_ns_compute_vector(const Vmm &vmm_src);
- void elu_compute_vector(const Vmm &vmm_src);
- void tanh_compute_vector(const Vmm &vmm_src);
- void square_compute_vector(const Vmm &vmm_src);
- void abs_compute_vector(const Vmm &vmm_src);
- void sqrt_compute_vector(const Vmm &vmm_src);
- void linear_compute_vector(const Vmm &vmm_src);
- void bounded_relu_compute_vector(const Vmm &vmm_src);
- void soft_relu_compute_vector(const Vmm &vmm_src);
- void logistic_compute_vector(const Vmm &vmm_src);
-
- void relu_prepare_table();
- void elu_prepare_table();
- void soft_relu_prepare_table();
- void abs_prepare_table();
- void sqrt_prepare_table();
- void linear_prepare_table();
- void bounded_relu_prepare_table();
-};
-
-struct jit_uni_eltwise_kernel_f32;
-
-template <cpu_isa_t isa>
-struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_eltwise_fwd_pd_t {
- using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_eltwise_fwd_t<isa>);
-
- status_t init();
- };
-
- jit_uni_eltwise_fwd_t(const pd_t *apd);
- ~jit_uni_eltwise_fwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_uni_eltwise_kernel_f32 *kernel_;
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_eltwise_bwd_pd_t {
- using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_eltwise_bwd_t<isa>);
-
- status_t init();
- };
-
- jit_uni_eltwise_bwd_t(const pd_t *apd);
- ~jit_uni_eltwise_bwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_uni_eltwise_kernel_f32 *kernel_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp
deleted file mode 100644
index a3ca6273a0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp
+++ /dev/null
@@ -1,949 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "jit_uni_i8i8_pooling.hpp"
-
-#include <math.h>
-
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::types;
-using namespace alg_kind;
-
-template <cpu_isa_t isa>
-struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
-
- struct call_params_t {
- const char *src_i8;
- const char *dst_i8;
- size_t kw_range;
- size_t kh_range;
- float idivider;
- };
-
- using Vmm = typename cpu_isa_traits<isa>::Vmm;
- Xmm xreg(int idx) const { return Xmm(idx); }
- Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
- Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
-
- // In case of avx2 with data type i8 we need to use
- // maskmovdqu instruction which has its destination hardcoded in rdi.
- // Windows ABI: abi_param1 is rcx - nothing else to do
- // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1
- Reg64 reg_param = rcx; // Our "unified abi_param1"
- Reg64 reg_ptr_src_i8 = r8;
- Reg64 reg_ptr_dst_i8 = r9;
- Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
-
- Reg64 ki = r10;
- Reg64 kj = r11;
- Reg64 reg_kw = r12;
- Reg64 reg_kh = r13;
- Reg64 c_iter = r14;
-
- Reg64 aux_reg_src_h = rax;
- Reg64 aux_reg_src_w = rbx;
-
- Reg64 reg_tmp = rdx;
-
- Reg64 reg_mask = r15;
-
- Opmask k_cmp_mask = Opmask(7);
-
- Opmask mask(int idx) {
- return Opmask(6 - idx);
- }
-
- // ref to any of XYZ-regs via xreg/yreg/vreg functions
- Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
- Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
- Vmm vreg_zeros = vreg(1);
-
- // only in case of <isa> == avx2
- Vmm vreg_mask = vreg(2); // full byte-mask
- Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
- Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately)
- Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
- Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
-
- enum:int {vidx_base = isa == avx2 ? 4 : 2};
- Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }
-
- size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
- size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
-
- /* max pooling */
- Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1]
- Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
-
- /* avg pooling */
- // s32 used for processing of s8/u8 data
- // thus we need to take into account ratio of sizes s32/i8 = 4
- static constexpr data_type_t avg_proc_dt = data_type::s32;
- enum:int {
- s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
- / sizeof(typename prec_traits<data_type::u8>::type),
- max_num_ll = s32_to_i8_ratio
- };
- Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3]
- Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7]
- Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11]
-
- void (*ker_)(const call_params_t *);
- jit_pool_conf_t jpp;
-
- void init_tmp_reg();
- void init_mask();
-
- void load_vreg_mask_q(int ll) {};
-
- void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
- void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
- void load_src(int jj, int ll, int c_tail);
-
- void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
- void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
- void store_dst(int jj, int ll, int c_tail);
-
- void compute_avg_step(int ur_c, int c_tail);
- void compute_max_op(const int jj);
- void compute_max_step(int ur_c, int c_tail);
- void compute_step(int ur_c, int c_tail);
-
- void compute_c_block();
- void generate();
-
- static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd);
-
- jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
- : jpp(jpp_) {
- generate();
- ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
- getCode()));
- }
-};
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
-
- // extract ll-th part of mask (ll-th QWORD)
- vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD
-
- // Move mask from ll-th pos to 0-th pos
- if (ll>0)
- vpermq(vreg_mask_q, vreg_mask_q, ll);
-};
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- if (masked) {
- if (jpp.src_dt == s32) {
- vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk));
- } else {
- vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask);
- }
- } else
- vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
-};
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- if (masked) {
- if (jpp.src_dt == s32)
- vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
- else
- vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
- } else
- vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
-};
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- // Don't generate useless code
- if (masked && !msk)
- return;
-
- auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
-
- // Need to use mask of tail?
- if (masked) {
-
- // load ll-th part of mask into vreg_mask_q
- load_vreg_mask_q(ll);
-
- // Load by mask from mem into register vr_src
- vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);
-
- // Conversion s8/u8 -> s32
- if (is_signed)
- vpmovsxbd(vr_src, vr_src);
- else
- vpmovzxbd(vr_src, vr_src);
- } else {
-
- // Load from mem into vr_src with conversion
- if (is_signed)
- vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
- else
- vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
- }
- };
-
- switch (jpp.src_dt) {
- case s32:
- if (masked)
- vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
- static_cast<uint8_t>(msk));
- else
- vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
- break;
- case s8:
- load_i8(true, vreg_src_s32(jj, ll));
- break;
- case u8:
- load_i8(false, vreg_src_s32(jj, ll));
- break;
- default: assert(!"unsupported src data type");
- }
-};
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- // Don't generate useless code
- if (masked && !msk)
- return;
-
- const Vmm& vr_src = masked ?
- vreg_src_s32(jj, ll) | mask(ll) :
- vreg_src_s32(jj, ll);
-
- switch (jpp.src_dt) {
- case s32:
- vmovups(vr_src, ptr[aux_reg_src_w + offset]);
- break;
- case s8:
- vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
- break;
- case u8:
- vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
- break;
- default: assert(!"unsupported src data type");
- }
-};
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
- using namespace data_type;
-
- int c_block = jpp.c_block;
- int ur_c = jpp.ur_c;
-
- switch (jpp.alg) {
- case pooling_max: {
- auto offset = jj*c_block*sizeof_src_dt();
- bool masked = jj == ur_c - 1 && c_tail;
- load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
- break;
- }
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding: {
- auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
- bool masked = jj == ur_c - 1 && c_tail;
- load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
- break;
- }
- default: assert(!"unsupported algorithm");
- }
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- int c_block = jpp.c_block;
-
- if (masked) {
- switch (jpp.src_dt) {
- case s32:
- vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
- break;
- case s8:
- case u8: {
- // Store low half by mask (bytes 0...15)
- lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
- maskmovdqu(vreg_dst(jj), xreg_mask_lo);
-
- // Do we need to store high half (bytes 16...31) ?
- const uint64_t low_mask = (1ULL << (c_block/2))-1;
- if (msk & ~low_mask) {
- vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
- add(reg_ptr_maskmovdqu_dst, c_block / 2);
- maskmovdqu(vreg_dst(jj), xreg_mask_hi);
- }
- } break;
- default: assert(!"unsupported src data type");
- }
- } else
- vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- if (masked) {
- switch (jpp.src_dt) {
- case s32:
- vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
- break;
- case s8:
- case u8:
- vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
- break;
- default: assert(!"unsupported src data type");
- }
- } else
- vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk){
- using namespace data_type;
-
- // Don't generate useless code
- if (masked && !msk)
- return;
-
- auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
-
- // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
- // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
- if (is_signed)
- vpackssdw(vr_dst, vr_dst, vreg_zeros);
- else
- vpackusdw(vr_dst, vr_dst, vreg_zeros);
-
- // Permute qwords to restore original order
- // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
- vpermq(vr_dst, vr_dst, 0x58);
-
- // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
- // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
- if (is_signed)
- vpacksswb(vr_dst, vr_dst, vreg_zeros);
- else
- vpackuswb(vr_dst, vr_dst, vreg_zeros);
-
- };
-
- auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
-
- // Conversion s32 -> s8/u8
- s32_to_i8(is_signed, vr_dst);
-
- // Need to use mask of tail?
- if (is_masked) {
- // load ll-th part of mask into vreg_mask_q
- load_vreg_mask_q(ll);
- }
-
- // store 8 bytes
- lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
- maskmovdqu(vr_dst, xreg_mask_q);
- };
-
- switch (jpp.dst_dt) {
- case s32:
- if (masked) {
- vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
- } else
- vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
- break;
- case s8:
- store_i8(true, masked, vreg_dst_s32(jj, ll));
- break;
- case u8:
- store_i8(false, masked, vreg_dst_s32(jj, ll));
- break;
- default: assert(!"unsuppotred dst data_type");
- }
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
- size_t offset, bool masked, uint64_t msk) {
- using namespace data_type;
-
- // Don't generate useless code
- if (masked && !msk)
- return;
-
- const Vmm& vr_dst = masked ?
- vreg_dst_s32(jj, ll) | mask(ll) :
- vreg_dst_s32(jj, ll);
-
- switch (jpp.dst_dt) {
- case s32:
- vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
- break;
- case s8:
- vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
- break;
- case u8:
- vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
- break;
- default: assert(!"unsupported dst data_type");
- }
-}
-
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
- int c_tail) {
- using namespace data_type;
-
- int c_block = jpp.c_block;
- int ur_c = jpp.ur_c;
-
- switch(jpp.alg) {
- case pooling_max: {
- auto offset = jj*c_block*sizeof_dst_dt();
- bool masked = jj == ur_c - 1 && c_tail;
- store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
- break;
- }
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding: {
- auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
- bool masked = jj == ur_c - 1 && c_tail;
- store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
- break;
- }
- default: assert(!"unsupported pooling algorithm");
- }
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
-{
- using namespace data_type;
- switch (jpp.src_dt) {
- case s32:
- vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
- break;
- case s8:
- vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
- break;
- case u8:
- vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
- break;
- default: assert(!"unsupported src data type");
- }
-}
-
-template <>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
-{
- using namespace data_type;
-
- // Compare
- switch (jpp.src_dt) {
- case s32:
- vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
- break;
- case s8:
- vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
- break;
- case u8:
- vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
- break;
- default: assert(!"unsupported src data type");
- }
-
- // move max values into vreg_dst
- if (jpp.src_dt == s32)
- vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
- else
- vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
-}
-
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
-{
- Label l_kw, l_kh;
-
- int iw = jpp.iw;
- int c = jpp.c;
-
- for (int jj = 0; jj < ur_c; jj++)
- vmovups(vreg_dst(jj), vreg_tmp);
-
- mov(aux_reg_src_h, reg_ptr_src_i8);
-
- xor_(kj, kj);
- L(l_kh);
- {
- mov(aux_reg_src_w, aux_reg_src_h);
- xor_(ki, ki);
- L(l_kw);
- {
- for (int jj = 0; jj < ur_c; jj++) {
- load_src(jj, 0, c_tail);
- compute_max_op(jj);
- }
- add(aux_reg_src_w, c * sizeof_src_dt());
- inc(ki);
- cmp(ki, reg_kw);
- jl(l_kw, T_NEAR);
- }
- add(aux_reg_src_h, iw * c * sizeof_src_dt());
- inc(kj);
- cmp(kj, reg_kh);
- jl(l_kh, T_NEAR);
- }
-
- for (int jj = 0; jj < ur_c; jj++)
- store_dst(jj, 0, c_tail);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
-{
- using namespace data_type;
-
- Label l_kw, l_kh;
-
- int iw = jpp.iw;
- int c = jpp.c;
-
- const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);
-
- for (int jj = 0; jj < ur_c; jj++) {
- for (int ll = 0; ll < num_ll; ll++) {
- bool masked = jj == ur_c - 1 && c_tail;
- size_t msk = jpp.tail[ll];
- if (!(masked && !msk)) {
- uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
- uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
- }
- }
- }
-
- mov(aux_reg_src_h, reg_ptr_src_i8);
-
- xor_(kj, kj);
- L(l_kh);
- {
- mov(aux_reg_src_w, aux_reg_src_h);
- xor_(ki, ki);
- L(l_kw);
- {
- for (int jj = 0; jj < ur_c; jj++) {
- for (int ll = 0; ll < num_ll; ll++) {
- bool masked = jj == ur_c - 1 && c_tail;
- size_t msk = jpp.tail[ll];
- if (!(masked && !msk)) {
- load_src(jj, ll, c_tail);
- vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
- vreg_src_s32(jj, ll));
- }
- }
- }
- add(aux_reg_src_w, c * sizeof_src_dt());
- inc(ki);
- cmp(ki, reg_kw);
- jl(l_kw, T_NEAR);
- }
- add(aux_reg_src_h, iw * c * sizeof_src_dt());
- inc(kj);
- cmp(kj, reg_kh);
- jl(l_kh, T_NEAR);
- }
-
- for (int jj = 0; jj < ur_c; jj++) {
- for (int ll = 0; ll < num_ll; ll++) {
- bool masked = jj == ur_c - 1 && c_tail;
- size_t msk = jpp.tail[ll];
- if (!(masked && !msk)) {
- vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
- vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
- vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
- store_dst(jj, ll, c_tail);
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
- switch (jpp.alg) {
- case pooling_max:
- compute_max_step(ur_c, c_tail); break;
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding:
- compute_avg_step(ur_c, c_tail); break;
- default: assert(!"unsupported pooling algorithm");
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
- Label l_main_loop;
-
- int nb_c = jpp.nb_c;
- int c_block = jpp.c_block;
- int ur_c = jpp.ur_c;
- int ur_c_tail = jpp.ur_c_tail;
- int c_steps = nb_c / ur_c;
- int c_tail = jpp.c_tail;
-
- xor_(c_iter, c_iter);
- if (c_steps > 0) {
- L(l_main_loop); {
- compute_step(ur_c, 0);
- add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
- add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
- inc(c_iter);
- cmp(c_iter, c_steps);
- jl(l_main_loop, T_NEAR);
- }
- }
-
- if (ur_c_tail != 0) {
- compute_step(ur_c_tail, c_tail);
- }
-}
-
-template<>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
- using namespace data_type;
- using cpu_isa = cpu_isa_traits<avx2>;
-
- // AVX2 mask initialization: mask stored in Ymm-regs
- auto init = [&](uint64_t bit_mask, bool init_mask_q) {
- const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
-
- uint64_t vmask[QW_PER_VREG];
- for (size_t i = 0; i < QW_PER_VREG; i++){
-
- uint64_t qw_vmask=0ULL;
- const size_t DBITS = 8*sizeof_src_dt();
- const uint64_t VMSK = 1ULL << (DBITS-1);
- const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
- for (size_t j = 0; j < D_PER_QW; j++) {
- if (bit_mask & 1)
- qw_vmask |= VMSK << DBITS * j;
- bit_mask >>= 1;
- }
- vmask[i] = qw_vmask;
- }
-
- // Put QWORDS with target mask into xmm regs
- const int xdst_i[QW_PER_VREG] = {
- xreg_mask_lo.getIdx(),
- xreg_mask_lo.getIdx(),
- xreg_mask_hi.getIdx(),
- xreg_mask_hi.getIdx()
- };
- const int xsrc_i[QW_PER_VREG] = {
- vreg_zeros.getIdx(), // 0-th qword insert in zeros -> {qw0, 0}
- xreg_mask_lo.getIdx(), // 1-st and 0-th merge -> {qw0,qw1}
- vreg_zeros.getIdx(),
- xreg_mask_hi.getIdx()
- };
- const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg
-
- for (size_t i = 0; i < QW_PER_VREG; i++) {
- mov(reg_mask, vmask[i]);
- vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
- }
-
- // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
- // and High (xreg_mask_hi) into full vreg_mask
- // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
- vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
-
- // Keep only low qword of mask in xreg_mask_q
- if (init_mask_q) {
- mov(reg_mask, vmask[0]);
- vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
- }
- };
-
- uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
- switch (jpp.alg) {
- case pooling_max:
- // For "max" we need mask only in case of non-zero tail
- if (tail_mask)
- init(tail_mask, false);
- break;
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding:
- // For "avg" we need mask:
- // - s32 - in case of the non-zero tail
- // - s8/u8 - irrespective of the tail
- switch (jpp.src_dt) {
- case s32:
- if (tail_mask)
- init(tail_mask, false);
- break;
- case s8:
- case u8:
- init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
- break;
- default: assert(!"unsupported src data type");
- }
- break;
- default: assert(!"unsupported pooling algorithm");
- }
-}
-
-template<>
-void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
-
- for (int ll = 0; ll < max_num_ll; ll++) {
- mov(reg_mask, jpp.tail[ll]);
- kmovq(mask(ll), reg_mask);
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
- using namespace data_type;
-
- switch (jpp.alg) {
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding:
- mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
- movq(xmm_tmp, reg_tmp);
- vpbroadcastd(vreg_tmp, xmm_tmp);
- break;
- case pooling_max:
- switch (jpp.src_dt) {
- case s32:
- mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
- break;
- case s8:
- mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
- break;
- case u8:
- mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
- break;
- default: assert(!"unsupported src data_type");
- }
-
- movq(xmm_tmp, reg_tmp);
- if (jpp.src_dt == s32)
- vpbroadcastd(vreg_tmp, xmm_tmp);
- else
- vpbroadcastb(vreg_tmp, xmm_tmp);
- break;
- default: assert(!"unsupported pooling algorithm");
- }
-
-}
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
- preamble();
-
-#if !defined(_WIN32)
- // Always use rcx as abi_param1 -
- // see the note about maskmovdqu near reg_param.
- mov(rcx, rdi);
-#endif
-
-# define READ_PARAM(reg, field) \
- mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
- READ_PARAM(reg_ptr_src_i8, src_i8);
- READ_PARAM(reg_ptr_dst_i8, dst_i8);
- READ_PARAM(reg_kw, kw_range);
- READ_PARAM(reg_kh, kh_range);
-
-# undef READ_PARAM
-
- uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
-
- init_mask();
-
- init_tmp_reg();
-
- compute_c_block();
-
- postamble();
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
- const pooling_pd_t *ppd) {
- if (!mayiuse(isa))
- return status::unimplemented;
-
- const auto &pd = *ppd->desc();
- const memory_desc_wrapper src_d(ppd->src_md());
- const memory_desc_wrapper dst_d(ppd->dst_md());
-
- jpp.mb = src_d.dims()[0];
- jpp.c = src_d.dims()[1];
- jpp.ih = src_d.dims()[2];
- jpp.iw = src_d.dims()[3];
- jpp.oh = dst_d.dims()[2];
- jpp.ow = dst_d.dims()[3];
-
- jpp.stride_h = pd.strides[0];
- jpp.stride_w = pd.strides[1];
- jpp.kh = pd.kernel[0];
- jpp.kw = pd.kernel[1];
-
- jpp.t_pad = pd.padding[0][0];
- jpp.l_pad = pd.padding[0][1];
-
- jpp.alg = pd.alg_kind;
-
- jpp.src_dt = pd.src_desc.data_type;
- jpp.dst_dt = pd.dst_desc.data_type;
-
- // data_type items per one vreg on the <isa>
- // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32
- // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
- int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
-
- jpp.c_block = simd_w;
- jpp.c_tail = jpp.c % jpp.c_block;
- jpp.nb_c = jpp.c / jpp.c_block;
- jpp.ur_c = 1;
- jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
- (jpp.c_tail != 0);
-
- size_t tail_mask = (1ULL << jpp.c_tail) - 1;
-
- switch (jpp.alg) {
- case pooling_max:
- jpp.tail[0] = tail_mask;
- jpp.tail[1] = 0;
- jpp.tail[2] = 0;
- jpp.tail[3] = 0;
- break;
- case pooling_avg_include_padding:
- case pooling_avg_exclude_padding: {
- // avg_proc_dt (s32) defines granularity (because u8/s8 processed as s32)
- // avx2 : 8, avx512 : 16
- const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
- const size_t msk_msk = (1ULL << msk_gran) - 1;
- size_t m = tail_mask;
- for (size_t ll = 0; ll < max_num_ll; ll++) {
- jpp.tail[ll] = m & msk_msk;
- m = m >> msk_gran;
- }
- break;
- }
- default: return status::unimplemented;
- }
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
- return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this);
-}
-
-template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::
-jit_uni_i8i8_pooling_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd), ker_(nullptr)
-{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }
-
-template <cpu_isa_t isa>
-jit_uni_i8i8_pooling_fwd_t<isa>::
-~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
-
-template <cpu_isa_t isa>
-void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward(
- const exec_ctx_t &ctx) const {
- auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC);
- auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
-
- const auto &jpp = pd()->jpp_;
-
- parallel_nd(jpp.mb, jpp.oh, jpp.ow,
- [&](int n, int oh, int ow) {
- const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
- const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);
-
- const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
- const int kh_end = nstl::min(jpp.kh,
- jpp.ih + jpp.t_pad - oh * jpp.stride_h);
- const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
- const int kw_end = nstl::min(jpp.kw,
- jpp.iw + jpp.l_pad - ow * jpp.stride_w);
-
- auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
- p.src_i8 = &src_i8[
- src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
- p.dst_i8 = &dst_i8[
- dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
- p.kw_range = (size_t)(kw_end - kw_start);
- p.kh_range = (size_t)(kh_end - kh_start);
- p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
- p.kh_range*p.kw_range : jpp.kw*jpp.kh);
-
- ker_->ker_(&p);
- });
-}
-
-// Explicit instantiation only for supported <isa> values.
-//
-template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
-template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
-
-template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
-template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp
deleted file mode 100644
index d757679df5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp
+++ /dev/null
@@ -1,89 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP
-#define CPU_JIT_UNI_I8I8_POOLING_HPP
-
-#include "c_types_map.hpp"
-
-#include "cpu_pooling_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "cpu_isa_traits.hpp"
-#include "jit_primitive_conf.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-struct jit_uni_i8i8_pooling_fwd_ker_t;
-
-template <cpu_isa_t isa>
-struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_pooling_fwd_pd_t {
- using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_i8i8_pooling_fwd_t<isa>);
-
- status_t init() {
- bool ok = true
- && mayiuse(isa)
- && ndims() == 4
- && set_default_params() == status::success
- && desc()->prop_kind == prop_kind::forward_inference
- && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
- alg_kind::pooling_avg_include_padding,
- alg_kind::pooling_avg_exclude_padding)
- && utils::one_of(src_md()->data_type, data_type::s32,
- data_type::s8, data_type::u8)
- && src_md()->data_type == dst_md()->data_type
- && attr()->has_default_values()
- && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
- && memory_desc_matches_tag(*dst_md(), format_tag::nhwc);
- if (!ok) return status::unimplemented;
-
- return jit_conf();
- }
-
- jit_pool_conf_t jpp_;
-
- protected:
- status_t jit_conf();
- };
-
- jit_uni_i8i8_pooling_fwd_t(const pd_t *apd);
- ~jit_uni_i8i8_pooling_fwd_t();
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_i8i8_pooling_fwd_ker_t<isa> *ker_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
deleted file mode 100644
index 2c5a8e8973..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
+++ /dev/null
@@ -1,305 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_uni_lrn_kernel_f32.hpp"
-#include "jit_uni_lrn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::utils;
-
-template <cpu_isa_t isa>
-jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(const pd_t *apd)
- : cpu_primitive_t(apd), ker_(nullptr)
- , ker_first_(nullptr), ker_last_(nullptr)
-{
- using namespace alg_kind;
-
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const int ls = pd()->desc()->local_size;
- float A = pd()->desc()->lrn_alpha / ls;
- float K = pd()->desc()->lrn_k;
-
- auto pk = pd()->desc()->prop_kind;
- auto ak = pd()->desc()->alg_kind;
- auto dat_tag = pd()->dat_tag_;
-
- if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
- ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw8c_across(H, W, 0), A, K, pk);
- ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw8c_across(H, W, -1), A, K, pk);
- ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw8c_across(H, W, +1), A, K, pk);
- } else if (dat_tag == nChw8c && ak == lrn_within_channel) {
- /* within channel, local_size (x) local_size */
- A /= ls; /* XXX: why? */
- ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw8c_within(H, W, ls), A, K, pk);
- } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
- ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw_across(C, H*W, 0), A, K, pk);
- int remind = (H*W) % VECTOR_LENGTH;
- if (remind != 0) {
- ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
- nchw_across(C, H*W, remind), A, K, pk);
- }
- } else if (true /* XXX: why */) {
- ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk);
- }
-}
-
-template <cpu_isa_t isa>
-jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
-{ delete ker_; delete ker_first_; delete ker_last_; }
-
-template <cpu_isa_t isa>
-void jit_uni_lrn_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
- using namespace alg_kind;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE);
-
- const int N = pd()->MB();
- const int C = pd()->C();
- const int HW = pd()->H() * pd()->W();
- const int ls = pd()->desc()->local_size;
-
- auto ak = pd()->desc()->alg_kind;
- auto dat_tag = pd()->dat_tag_;
-
- if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
- parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
- jit_args_fwd_t args;
- args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
- args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
- args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
- if (c8 == 0)
- (*ker_first_)(&args);
- else if (c8 == C / VECTOR_LENGTH - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- });
- }
- else if (dat_tag == nChw8c && ak == lrn_within_channel) {
- parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
- jit_args_fwd_t args;
- args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
- args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
- args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
- (*ker_)(&args);
- });
- }
- else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
- parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
- [&](int n, int hw8) {
- jit_args_fwd_t args;
- args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH];
- args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH];
- args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH];
- if ((hw8 + 1)*VECTOR_LENGTH > HW)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- });
- }
- else { // nhwc
- parallel_nd(N, HW, [&](int n, int hw) {
- jit_args_fwd_t args;
- args.src = &src[n*HW*C + hw * C];
- args.dst = &dst[n*HW*C + hw * C];
- args.scratch = &ws[n*HW*C + hw * C];
- (*ker_)(&args);
- });
- }
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
- using namespace prop_kind;
- using namespace alg_kind;
-
- const memory_desc_wrapper data_d(src_md());
- bool ok = true
- && mayiuse(isa)
- && is_fwd()
- && everyone_is(data_type::f32, data_d.data_type())
- && !has_zero_dim_memory()
- && data_d.ndims() == 4
- && data_d.dims()[1] % VECTOR_LENGTH == 0
- && data_d.dims()[1] >= 2 * VECTOR_LENGTH
- && desc()->lrn_beta == 0.75
- && attr()->has_default_values();
- if (!ok) return unimplemented;
-
- if (desc_.prop_kind == forward_training) ws_md_ = *src_md();
-
- dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc);
-
- bool args_ok_across = true
- && desc()->alg_kind == lrn_across_channels
- && desc()->local_size == 5
- && one_of(dat_tag_, nChw8c, nchw, nhwc);
-
- const int jit_max_local_size = 5; // bigger size triggers too big code size
- bool args_ok_within = true
- && desc()->alg_kind == lrn_within_channel
- && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE
- ? jit_max_local_size : MAX_LOCAL_SIZE)
- && data_d.dims()[2] >= desc()->local_size
- && data_d.dims()[3] >= desc()->local_size
- && one_of(dat_tag_, nChw8c);
-
- return args_ok_across || args_ok_within ? success : unimplemented;
-}
-
-template <cpu_isa_t isa>
-jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd)
- : cpu_primitive_t(apd)
- , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
-{
- using namespace alg_kind;
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const int ls = pd()->desc()->local_size;
- float A = pd()->desc()->lrn_alpha / ls;
- float B = pd()->desc()->lrn_beta;
-
- int use_h_parallelizm = 0;// XXX
- if (C / VECTOR_LENGTH == 1) {
- ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
- nchw8c_across(H, W, 3), A, B, use_h_parallelizm);
- }
- else {
- ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
- nchw8c_across(H, W, 0), A, B, use_h_parallelizm);
- ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
- nchw8c_across(H, W, -1), A, B, use_h_parallelizm);
- ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
- nchw8c_across(H, W, +1), A, B, use_h_parallelizm);
- }
-}
-
-template <cpu_isa_t isa>
-jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
-{
- delete ker_; delete ker_first_; delete ker_last_;
-}
-
-template <cpu_isa_t isa>
-void jit_uni_lrn_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const int N = pd()->MB();
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
-
- int use_h_parallelizm = 0; // XXX
- if (use_h_parallelizm) {
- parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) {
- auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH
- + h*W*VECTOR_LENGTH;
- jit_args_bwd_t args;
- args.src = &src[offset];
- args.diff_dst = &diff_dst[offset];
- args.scratch = &ws[offset];
- args.diff_src = &diff_src[offset];
- if (C / VECTOR_LENGTH == 1)
- (*ker_)(&args);
- else if (c8 == 0)
- (*ker_first_)(&args);
- else if (c8 == C / VECTOR_LENGTH - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- });
- }
- else {
- parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
- auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH;
- jit_args_bwd_t args;
- args.src = &src[offset];
- args.diff_dst = &diff_dst[offset];
- args.scratch = &ws[offset];
- args.diff_src = &diff_src[offset];
- if (C / VECTOR_LENGTH == 1)
- (*ker_)(&args);
- else if (c8 == 0)
- (*ker_first_)(&args);
- else if (c8 == C / VECTOR_LENGTH - 1)
- (*ker_last_)(&args);
- else
- (*ker_)(&args);
- });
- }
-}
-
-template <cpu_isa_t isa>
-status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() {
- using namespace prop_kind;
- using namespace alg_kind;
-
- const memory_desc_wrapper data_d(src_md());
- bool ok = true
- && mayiuse(isa)
- && !is_fwd()
- && utils::everyone_is(data_type::f32, data_d.data_type())
- && !has_zero_dim_memory()
- && data_d.ndims() == 4
- && data_d.dims()[1] % VECTOR_LENGTH == 0
- && desc()->lrn_beta == 0.75
- && attr()->has_default_values();
- if (!ok) return unimplemented;
-
- ws_md_ = *src_md();
- if (!compare_ws(hint_fwd_pd_)) return unimplemented;
-
- dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c);
-
- bool args_ok_across = true
- && desc()->alg_kind == lrn_across_channels
- && desc()->local_size == 5
- && utils::one_of(dat_tag_, nChw8c);
-
- return args_ok_across ? success : unimplemented;
-}
-
-template struct jit_uni_lrn_fwd_t<sse42>;
-template struct jit_uni_lrn_fwd_t<avx2>;
-template struct jit_uni_lrn_bwd_t<avx2>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp
deleted file mode 100644
index 333cd3396d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp
+++ /dev/null
@@ -1,103 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_LRN_HPP
-#define CPU_JIT_UNI_LRN_HPP
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_isa_traits.hpp"
-#include "cpu_lrn_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa> struct jit_uni_lrn_fwd_kernel_f32;
-template <cpu_isa_t isa> struct jit_uni_lrn_bwd_kernel_f32;
-
-template <cpu_isa_t isa>
-struct jit_uni_lrn_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_fwd_pd_t {
- using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_lrn_fwd_t<isa>);
-
- status_t init();
-
- format_tag_t dat_tag_;
- };
-
- jit_uni_lrn_fwd_t(const pd_t *apd);
- ~jit_uni_lrn_fwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_lrn_fwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_;
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_lrn_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_bwd_pd_t {
- using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_lrn_bwd_t<isa>);
-
- status_t init();
-
- format_tag_t dat_tag_;
- };
-
- jit_uni_lrn_bwd_t(const pd_t *apd);
- ~jit_uni_lrn_bwd_t();
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- jit_uni_lrn_bwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp
deleted file mode 100644
index 89af47272c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp
+++ /dev/null
@@ -1,1487 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-
-#include "jit_uni_lrn_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-//////////////////////////////////////////////////////////////////////////////
-// forward kernel
-template<cpu_isa_t isa>
-void jit_uni_lrn_fwd_kernel_f32<isa>::within_body(
- int hoff, int Hoff, int woff, int Woff, int stride,
- Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
- prop_kind_t pk)
-{
- vxorps(ysum, ysum, ysum);
- for (int i = hoff; i <= Hoff; ++i)
- {
- for (int j = woff; j <= Woff; ++j)
- {
- if (i == 0 && j == 0)
- {
- vmovups(ydst, ptr[src]);
- vfmadd231ps(ysum, ydst, ydst);
- }
- else
- {
- vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]);
- vfmadd231ps(ysum, ytmp, ytmp);
- }
- }
- }
- vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
- vmovaps(ytmp, ysum);
- if (pk != prop_kind::forward_inference)
- vmovups(ptr[scratch], ytmp);
- vmulps(ysum2, ysum, ysum);
- vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3;
- vsqrtps(ysum, ysum);
- vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75
- vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum
- vmovups(ptr[dst], ydst);
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
-}
-
-template<cpu_isa_t isa>
-void jit_uni_lrn_fwd_kernel_f32<isa>::within_body_sse42(
- int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk)
-{
- Xbyak::Xmm xtmp_lo = xmm12;
- Xbyak::Xmm xtmp_hi = xmm13;
- Xbyak::Xmm xsum_lo = xmm8;
- Xbyak::Xmm xsum_hi = xmm9;
- Xbyak::Xmm xdst_lo = xmm10;
- Xbyak::Xmm xdst_hi = xmm11;
- Xbyak::Xmm xsum2_lo = xmm14;
- Xbyak::Xmm xsum2_hi = xmm15;
-
- xorps(xsum_lo, xsum_lo);
- xorps(xsum_hi, xsum_hi);
- for (int i = hoff; i <= Hoff; ++i)
- {
- for (int j = woff; j <= Woff; ++j)
- {
- if (i == 0 && j == 0)
- {
- movups(xdst_lo, ptr[src]);
- movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
- mulps(xdst_lo, xdst_lo);
- mulps(xdst_hi, xdst_hi);
- addps(xsum_lo, xdst_lo);
- addps(xsum_hi, xdst_hi);
- }
- else
- {
- movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]);
- movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]);
- mulps(xtmp_lo, xtmp_lo);
- mulps(xtmp_hi, xtmp_hi);
- addps(xsum_lo, xtmp_lo);
- addps(xsum_hi, xtmp_hi);
- }
- }
- }
- mulps(xsum_lo, xalpha);
- mulps(xsum_hi, xalpha);
- addps(xsum_lo, xk);
- addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
- movaps(xtmp_lo, xsum_lo);
- movaps(xtmp_hi, xsum_hi);
- if (pk != prop_kind::forward_inference) {
- movups(ptr[scratch], xtmp_lo);
- movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi);
- }
- movaps(xsum2_lo, xsum_lo);
- movaps(xsum2_hi, xsum_hi);
- mulps(xsum2_lo, xsum_lo);
- mulps(xsum2_hi, xsum_hi);
- mulps(xsum_lo, xsum2_lo);
- mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3;
-
- sqrtps(xsum_lo, xsum_lo);
- sqrtps(xsum_hi, xsum_hi);
- sqrtps(xsum_lo, xsum_lo);
- sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75
-
- movups(xdst_lo, ptr[src]);
- movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
- divps(xdst_lo, xsum_lo);
- divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum
-
- movups(ptr[dst], xdst_lo);
- movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
-}
-
-template <cpu_isa_t isa>
-jit_uni_lrn_fwd_kernel_f32<isa>::jit_uni_lrn_fwd_kernel_f32(
- const struct nchw8c_within &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
- Xbyak::Reg64 h = r9;
- Xbyak::Reg64 w = r10;
- Vmm ysum = Vmm(isa == avx2 ? 9 : 9);
- Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10);
- Vmm ydst = Vmm(isa == avx2 ? 11 : 11);
- Vmm ytmp = Vmm(isa == avx2 ? 12 : 12);
-
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
-
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- if (isa == avx2) {
- vbroadcastss(yalpha, xalpha);
- } else {
- shufps(xalpha, xalpha, 0);
- }
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- if (isa == avx2) {
- vbroadcastss(yk, xk);
- } else {
- shufps(xk, xk, 0);
- }
-
- int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1;
-
- for (int i = 0; i < s2; ++i)
- {
- Label label_t;
- for (int j = 0; j < s2; ++j) {
- if (isa == avx2) {
- within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
- }
- else {
- within_body_sse42(-i, S2, -j, S2, J.W, pk);
- }
- }
- mov(w, J.W - J.size + 1);
- L(label_t);
- if (isa == avx2) {
- within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-i, S2, -s2, S2, J.W, pk);
- }
- dec(w);
- cmp(w, 0);
- jne(label_t, T_NEAR);
- for (int j = J.W - S2; j < J.W; ++j) {
- if (isa == avx2) {
- within_body(-i, S2, -s2, J.W - 1 - j, J.W,
- ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk);
- }
- }
- }
-
- mov(h, J.H - J.size + 1);
- Label lrn_loop_h;
- L(lrn_loop_h);
- for (int j = 0; j < s2; ++j) {
- if (isa == avx2) {
- within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, S2, -j, S2, J.W, pk);
- }
- }
- mov(w, J.W - J.size + 1);
- Label lrn_loop_w;
- L(lrn_loop_w);
- if (isa == avx2) {
- within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, S2, -s2, S2, J.W, pk);
- }
- dec(w);
- cmp(w, 0);
- jne(lrn_loop_w, T_NEAR);
- for (int j = J.W - S2; j < J.W; ++j) {
- if (isa == avx2) {
- within_body(-s2, S2, -s2, J.W - 1 - j, J.W,
- ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk);
- }
- }
- dec(h);
- cmp(h, 0);
- jne(lrn_loop_h, T_NEAR);
-
- for (int i = J.H - S2; i < J.H; ++i)
- {
- for (int j = 0; j < s2; ++j) {
- if (isa == avx2) {
- within_body(-s2, J.H - 1 - i, -j, S2, J.W,
- ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk);
- }
- }
-
- mov(w, J.W - J.size + 1);
- Label label_b;
- L(label_b);
- if (isa == avx2) {
- within_body(-s2, J.H - 1 - i, -s2, S2, J.W,
- ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk);
- }
- dec(w);
- cmp(w, 0);
- jne(label_b, T_NEAR);
-
- for (int j = J.W - S2; j < J.W; ++j) {
- if (isa == avx2) {
- within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W,
- ysum, ydst, ytmp, ysum2, pk);
- } else {
- within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk);
- }
- }
- }
-
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
- const struct nchw8c_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
- Xbyak::Reg64 t = rsp;
- Xbyak::Reg64 hw = r9;
- Xbyak::Xmm xsrc_prev = xmm2;
- Xbyak::Ymm ysrc = ymm3;
- Xbyak::Ymm yc = ymm3;
- Xbyak::Xmm xsrc_next = xmm4;
- Xbyak::Ymm ya = ymm5;
- Xbyak::Ymm yb = ymm6;
- Xbyak::Ymm yd = ymm7;
- Xbyak::Ymm ye = ymm8;
- Xbyak::Ymm ysum = ymm9;
- Xbyak::Ymm ysum2 = ymm10;
- Xbyak::Ymm ydst = ymm11;
- Xbyak::Ymm ybase = ymm12;
-
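-    // A 64-byte stack scratchpad at 't' stages the tail of the previous channel block (16 B),
-    // the current 8 channels (32 B) and the head of the next block (16 B) contiguously, so the
-    // +/-1 and +/-2 channel neighbours can be read with plain unaligned loads at t+16-8 .. t+16+8.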
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
- sub(t, 64);
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- vbroadcastss(yalpha, xalpha);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- vbroadcastss(yk, xk);
-
- if (J.version == -1)
- {
- vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
- vmovups(ptr[t + 0], xsrc_prev);
- }
- if (J.version == +1)
- {
- vxorps(xsrc_next, xsrc_next, xsrc_next);
- vmovups(ptr[t + 48], xsrc_next);
- }
-
- mov(hw, J.H*J.W);
-
- Label lrn_loop;
- L(lrn_loop);
-
- if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
- vmovups(ysrc, ptr[src]);
- if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
-
- if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev);
- vmovups(ptr[t + 16], ysrc);
- if (J.version != +1) vmovups(ptr[t + 48], xsrc_next);
-
- vmovups(ya, ptr[t + 16 - 8]);
- vmovups(yb, ptr[t + 16 - 4]);
- vmovups(yd, ptr[t + 16 + 4]);
- vmovups(ye, ptr[t + 16 + 8]);
- vmulps(ysum, yc, yc);
- vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
- vfmadd231ps(ysum, yb, yb);
- vfmadd231ps(ysum, yd, yd);
- vfmadd231ps(ysum, ye, ye);
- vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
-
- vmovaps(ybase, ysum);
- if (pk != prop_kind::forward_inference)
- vmovups(ptr[scratch], ybase);
- vmulps(ysum2, ysum, ysum);
- vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
- vsqrtps(ysum, ysum);
- vsqrtps(ysum, ysum); // ysum = ybase^0.75
- vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
- vmovups(ptr[dst], ydst);
-
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
- dec(hw);
- cmp(hw, 0);
- jne(lrn_loop, T_NEAR);
-
- add(t, 64);
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
- const struct nchw8c_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
- Xbyak::Reg64 t = rsp;
- Xbyak::Reg64 hw = r9;
-
- Xbyak::Xmm xsrc_lo = xmm2;
- Xbyak::Xmm xsrc_hi = xmm3;
- Xbyak::Xmm xc_lo = xmm4;
- Xbyak::Xmm xc_hi = xmm5;
- Xbyak::Xmm xsum_lo = xc_lo;
- Xbyak::Xmm xsum_hi = xc_hi;
- Xbyak::Xmm xsrc_prev = xmm6;
- Xbyak::Xmm xsrc_next = xmm7;
- Xbyak::Xmm xa_lo = xmm8;
- Xbyak::Xmm xa_hi = xmm9;
- Xbyak::Xmm xb_lo = xmm10;
- Xbyak::Xmm xb_hi = xmm11;
- Xbyak::Xmm xd_lo = xmm12;
- Xbyak::Xmm xd_hi = xmm13;
- Xbyak::Xmm xe_lo = xmm14;
- Xbyak::Xmm xe_hi = xmm15;
- Xbyak::Xmm xbase_lo = xmm14;
- Xbyak::Xmm xbase_hi = xmm15;
-
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
- sub(t, 64);
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- shufps(xalpha, xalpha, 0);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- shufps(xk, xk, 0);
-
- if (J.version == -1)
- {
- xorps(xsrc_prev, xsrc_prev);
- movups(ptr[t + 0], xsrc_prev);
- }
- if (J.version == +1)
- {
- xorps(xsrc_next, xsrc_next);
- movups(ptr[t + 48], xsrc_next);
- }
-
- mov(hw, J.H*J.W);
- Label lrn_loop;
- L(lrn_loop);
-
- if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
- movups(xsrc_lo, ptr[src]);
- movups(xsrc_hi, ptr[src + 4 * sizeof(float)]);
- if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]);
-
- if (J.version != -1) movups(ptr[t + 0], xsrc_prev);
- movups(ptr[t + 16], xsrc_lo);
- movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
- if (J.version != +1) movups(ptr[t + 48], xsrc_next);
-
- movups(xa_lo, ptr[t + 16 - 8]);
- movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]);
- movups(xb_lo, ptr[t + 16 - 4]);
- movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]);
- movups(xd_lo, ptr[t + 16 + 4]);
- movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]);
- movups(xe_lo, ptr[t + 16 + 8]);
- movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]);
- movaps(xc_lo, xsrc_lo);
- movaps(xc_hi, xsrc_hi);
- mulps(xsum_lo, xc_lo);
- mulps(xsum_hi, xc_hi);
- mulps(xa_lo, xa_lo);
- mulps(xa_hi, xa_hi);
- addps(xsum_lo, xa_lo);
- addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
- mulps(xb_lo, xb_lo);
- mulps(xb_hi, xb_hi);
- addps(xsum_lo, xb_lo);
- addps(xsum_hi, xb_hi);
- mulps(xd_lo, xd_lo);
- mulps(xd_hi, xd_hi);
- addps(xsum_lo, xd_lo);
- addps(xsum_hi, xd_hi);
- mulps(xe_lo, xe_lo);
- mulps(xe_hi, xe_hi);
- addps(xsum_lo, xe_lo);
- addps(xsum_hi, xe_hi);
-
- mulps(xsum_lo, xalpha);
- mulps(xsum_hi, xalpha);
- addps(xsum_lo, xk);
- addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
-
- movaps(xbase_lo, xsum_lo);
- movaps(xbase_hi, xsum_hi);
- if (pk != prop_kind::forward_inference) {
- movups(ptr[scratch], xbase_lo);
- movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
- }
- mulps(xsum_lo, xsum_lo);
- mulps(xsum_hi, xsum_hi);
- mulps(xsum_lo, xbase_lo);
- mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
- sqrtps(xsum_lo, xsum_lo);
- sqrtps(xsum_hi, xsum_hi);
- sqrtps(xsum_lo, xsum_lo);
- sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
- divps(xsrc_lo, xsum_lo);
- divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
- movups(ptr[dst], xsrc_lo);
- movups(ptr[dst + 4 * sizeof(float)], xsrc_hi);
-
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
- dec(hw);
- cmp(hw, 0);
- jne(lrn_loop, T_NEAR);
-
- add(t, 64);
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
- const struct nhwc_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
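-    // Sign-bit masks for vmaskmovps: only lanes whose entry has the MSB set are loaded, so the
-    // zero entries at either end suppress reads of channels before 0 and past C-1 at the row edges.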
- static const uint32_t mask[] = {
- 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
- 0x80000000, 0x80000000, 0x80000000, 0, 0
- };
-
- Xbyak::Reg64 c = r9;
- Xbyak::Ymm ya = ymm2;
- Xbyak::Ymm yb = ymm3;
- Xbyak::Ymm yc = ymm4;
- Xbyak::Ymm yd = ymm5;
- Xbyak::Ymm ye = ymm6;
- Xbyak::Ymm ysum = ymm7;
- Xbyak::Ymm ydst = ymm8;
- Xbyak::Ymm ybase = ymm9;
- Xbyak::Ymm ymask = ymm10;
-
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- vbroadcastss(yalpha, xalpha);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- vbroadcastss(yk, xk);
-
- vxorps(ysum, ysum, ysum);
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
- vmovups(ymask, ptr[imm_addr64]);
- vmaskmovps(ya, ymask, ptr[src - 8]);
- vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
- vmovups(ymask, ptr[imm_addr64]);
- vmaskmovps(yb, ymask, ptr[src - 4]);
- vfmadd231ps(ysum, yb, yb);
-
- mov(c, J.C / 8 - 1);
- Label lrn_loop;
- L(lrn_loop);
-
- vmovups(yc, ptr[src]);
- vmovups(yd, ptr[src + 4]);
- vmovups(ye, ptr[src + 8]);
- vfmadd231ps(ysum, yc, yc);
- vfmadd231ps(ysum, yd, yd);
- vfmadd231ps(ysum, ye, ye);
-
- vmovups(ydst, ysum);
- vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
-
- vmovaps(ybase, ydst);
- if (pk != prop_kind::forward_inference)
- vmovups(ptr[scratch], ybase);
- vmulps(ydst, ydst, ydst);
- vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
- vsqrtps(ydst, ydst);
- vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
-
- vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
- vmovups(ptr[dst], ydst);
-
- vxorps(ysum, ysum, ysum);
-
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
-
- vmovups(ya, ptr[src - 8]);
- vfmadd231ps(ysum, ya, ya);
- vmovups(yb, ptr[src - 4]);
- vfmadd231ps(ysum, yb, yb);
-
- dec(c);
- cmp(c, 0);
- jne(lrn_loop, T_NEAR);
-
- vmovups(yc, ptr[src]);
- vfmadd231ps(ysum, yc, yc);
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
- vmovups(ymask, ptr[imm_addr64]);
- vmaskmovps(yd, ymask, ptr[src + 4]);
- vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
- vmovups(ymask, ptr[imm_addr64]);
- vmaskmovps(ye, ymask, ptr[src + 8]);
- vfmadd231ps(ysum, ye, ye);
-
- vmovups(ydst, ysum);
- vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
-
- vmovaps(ybase, ydst);
- if (pk != prop_kind::forward_inference)
- vmovups(ptr[scratch], ybase);
- vmulps(ydst, ydst, ydst);
- vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
- vsqrtps(ydst, ydst);
- vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
- vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
-
- vmovups(ptr[dst], ydst);
-
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
- const struct nhwc_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
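-    // All-ones masks ANDed with the neighbouring-channel loads so that lanes outside the valid
-    // channel range contribute zero to the sum; 'store' is a small scratch area that keeps the
-    // broadcast alpha and k once the xmm registers are needed elsewhere.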
- static const uint32_t mask[] = {
- 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
- 0xffffffff, 0xffffffff, 0xffffffff, 0, 0
- };
-
- static uint32_t store[] = {
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
- };
- Xbyak::Reg64 c = r9;
-
- Xbyak::Xmm xdst_lo = xmm0;
- Xbyak::Xmm xdst_hi = xmm1;
- Xbyak::Xmm xa_lo = xmm2;
- Xbyak::Xmm xa_hi = xmm3;
- Xbyak::Xmm xb_lo = xmm2;
- Xbyak::Xmm xb_hi = xmm3;
- Xbyak::Xmm xc_lo = xmm4;
- Xbyak::Xmm xc_hi = xmm5;
- Xbyak::Xmm xd_lo = xmm6;
- Xbyak::Xmm xd_hi = xmm7;
- Xbyak::Xmm xe_lo = xmm8;
- Xbyak::Xmm xe_hi = xmm9;
- Xbyak::Xmm xsum_lo = xmm10;
- Xbyak::Xmm xsum_hi = xmm11;
- Xbyak::Xmm xmask_lo = xmm12;
- Xbyak::Xmm xmask_hi = xmm13;
- Xbyak::Xmm xbase_lo = xmm14;
- Xbyak::Xmm xbase_hi = xmm15;
-
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- shufps(xalpha, xalpha, 0);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- shufps(xk, xk, 0);
-
- mov(store_addr, reinterpret_cast<size_t>(&store[0]));
- and_(store_addr, -15);
- movups(ptr[store_addr], xalpha);
- movups(ptr[store_addr + 4 * sizeof(float)], xk);
-
- xorps(xsum_lo, xsum_lo);
- xorps(xsum_hi, xsum_hi);
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
- movups(xmask_lo, ptr[imm_addr64]);
- movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
- movups(xa_lo, ptr[src - 8]);
- movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
- andps(xa_lo, xmask_lo);
- andps(xa_hi, xmask_hi);
- mulps(xa_lo, xa_lo);
- mulps(xa_hi, xa_hi);
- addps(xsum_lo, xa_lo);
- addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
- movups(xmask_lo, ptr[imm_addr64]);
- movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
- movups(xb_lo, ptr[src - 4]);
- movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
- andps(xb_lo, xmask_lo);
- andps(xb_hi, xmask_hi);
- mulps(xb_lo, xb_lo);
- mulps(xb_hi, xb_hi);
- addps(xsum_lo, xb_lo);
- addps(xsum_hi, xb_hi);
-
- mov(c, J.C / 8 - 1);
- Label lrn_loop;
- L(lrn_loop);
-
- movups(xc_lo, ptr[src]);
- movups(xc_hi, ptr[src + 4 * sizeof(float)]);
- movups(xd_lo, ptr[src + 4]);
- movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
- movups(xe_lo, ptr[src + 8]);
- movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
- mulps(xc_lo, xc_lo);
- mulps(xc_hi, xc_hi);
- addps(xsum_lo, xc_lo);
- addps(xsum_hi, xc_hi);
- mulps(xd_lo, xd_lo);
- mulps(xd_hi, xd_hi);
- addps(xsum_lo, xd_lo);
- addps(xsum_hi, xd_hi);
- mulps(xe_lo, xe_lo);
- mulps(xe_hi, xe_hi);
- addps(xsum_lo, xe_lo);
- addps(xsum_hi, xe_hi);
-
- movaps(xdst_lo, xsum_lo);
- movaps(xdst_hi, xsum_hi);
- // xdst <- xsum*xalpha+xk
- mulps(xdst_lo, ptr[store_addr]);
- mulps(xdst_hi, ptr[store_addr]);
- addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
- addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
-
- movaps(xbase_lo, xdst_lo);
- movaps(xbase_hi, xdst_hi);
- if (pk != prop_kind::forward_inference) {
- movups(ptr[scratch], xbase_lo);
- movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
- }
- mulps(xdst_lo, xdst_lo);
- mulps(xdst_hi, xdst_hi);
- mulps(xdst_lo, xbase_lo);
- mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi);
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
-
- movups(xc_lo, ptr[src]);
- movups(xc_hi, ptr[src + 4 * sizeof(float)]);
- divps(xc_lo, xdst_lo);
- divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
- movups(ptr[dst], xc_lo);
- movups(ptr[dst + 4 * sizeof(float)], xc_hi);
-
- xorps(xsum_lo, xsum_lo);
- xorps(xsum_hi, xsum_hi);
-
- add(src, 32);
- add(dst, 32);
- if (pk != prop_kind::forward_inference)
- add(scratch, 32);
-
- movups(xa_lo, ptr[src - 8]);
- movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
- mulps(xa_lo, xa_lo);
- mulps(xa_hi, xa_hi);
- addps(xsum_lo, xa_lo);
- addps(xsum_hi, xa_hi);
- movups(xb_lo, ptr[src - 4]);
- movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
- mulps(xb_lo, xb_lo);
- mulps(xb_hi, xb_hi);
- addps(xsum_lo, xb_lo);
- addps(xsum_hi, xb_hi);
-
- dec(c);
- cmp(c, 0);
- jne(lrn_loop, T_NEAR);
-
- movups(xc_lo, ptr[src]);
- movups(xc_hi, ptr[src + 4 * sizeof(float)]);
- mulps(xc_lo, xc_lo);
- mulps(xc_hi, xc_hi);
- addps(xsum_lo, xc_lo);
- addps(xsum_hi, xc_hi);
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
- movups(xmask_lo, ptr[imm_addr64]);
- movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
- movups(xd_lo, ptr[src + 4]);
- movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
- andps(xd_lo, xmask_lo);
- andps(xd_hi, xmask_hi);
- mulps(xd_lo, xd_lo);
- mulps(xd_hi, xd_hi);
- addps(xsum_lo, xd_lo);
- addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
-
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
- movups(xmask_lo, ptr[imm_addr64]);
- movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
- movups(xe_lo, ptr[src + 8]);
- movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
- andps(xe_lo, xmask_lo);
- andps(xe_hi, xmask_hi);
- mulps(xe_lo, xe_lo);
- mulps(xe_hi, xe_hi);
- addps(xsum_lo, xe_lo);
- addps(xsum_hi, xe_hi);
-
- movups(xdst_lo, xsum_lo);
- movups(xdst_hi, xsum_hi);
- // xdst <- xsum*xalpha+xk
- mulps(xdst_lo, ptr[store_addr]);
- mulps(xdst_hi, ptr[store_addr]);
- addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
- addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
-
- movaps(xbase_lo, xdst_lo);
- movaps(xbase_hi, xdst_hi);
- if (pk != prop_kind::forward_inference) {
- movups(ptr[scratch], xbase_lo);
- movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
- }
- mulps(xdst_lo, xdst_lo);
- mulps(xdst_hi, xdst_hi);
- mulps(xdst_lo, xbase_lo);
- mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi);
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
- movups(xc_lo, ptr[src]);
- movups(xc_hi, ptr[src + 4 * sizeof(float)]);
- divps(xc_lo, xdst_lo);
- divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
-
- movups(ptr[dst], xc_lo);
- movups(ptr[dst + 4 * sizeof(float)], xc_hi);
-
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body(
- int tail, int HW, prop_kind_t pk,
- Xbyak::Ymm ymask,
- Xbyak::Ymm ya,
- Xbyak::Ymm yb,
- Xbyak::Ymm yc,
- Xbyak::Ymm yd,
- Xbyak::Ymm ye,
- Xbyak::Ymm ysum) {}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body(
- int tail, int HW, prop_kind_t pk,
- Xbyak::Ymm ymask,
- Xbyak::Ymm ya,
- Xbyak::Ymm yb,
- Xbyak::Ymm yc,
- Xbyak::Ymm yd,
- Xbyak::Ymm ye,
- Xbyak::Ymm ysum)
-{
- Xbyak::Ymm ydst = ymm14;
- Xbyak::Ymm ybase = ymm15;
-
- vfmadd231ps(ysum, ye, ye);
-
- vmovups(ydst, ysum);
- vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
-
- vmovaps(ybase, ydst);
- if (pk != prop_kind::forward_inference)
- {
- if (tail != 0)
- vmaskmovps(ptr[scratch], ymask, ybase);
- else
- vmovups(ptr[scratch], ybase);
- }
- vmulps(ydst, ydst, ydst);
- vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
- vsqrtps(ydst, ydst);
- vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
- vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
-
- if (tail != 0)
- vmaskmovps(ptr[dst], ymask, ydst);
- else
- vmovups(ptr[dst], ydst);
-
-
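-    // Slide the five-channel window: remove ya^2 from the running sum, then shift a <- b <- c <- d <- e.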
- vfnmadd231ps(ysum, ya, ya);
- vmovups(ya, yb);
- vmovups(yb, yc);
- vmovups(yc, yd);
- vmovups(yd, ye);
-}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42(
- int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
-{}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42(
- int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
-{
- Xbyak::Xmm xmm_tmp = xmm10;
- movaps(xmm_tmp, xtail_lo);
- size_t offset = 0;
-
- if (tail > 4) {
- movups(ptr[reg_dst], xtail_lo);
- movaps(xmm_tmp, xtail_hi);
- offset += 4 * sizeof(float);
- tail -= 4;
- }
- movss(ptr[reg_dst + offset], xmm_tmp);
- for (int i = 1; i < tail; i++)
- {
- psrldq(xmm_tmp, 4);
- movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp);
- }
-}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42(
- int tail, int HW, prop_kind_t pk,
- Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
- Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
- Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi)
-{
- Xbyak::Xmm xdst_lo = xmm0;
- Xbyak::Xmm xdst_hi = xmm1;
- Xbyak::Xmm xbase_lo = xmm6;
- Xbyak::Xmm xbase_hi = xmm7;
- Xbyak::Xmm xtmp_lo = xmm8;
- Xbyak::Xmm xtmp_hi = xmm9;
- Xbyak::Xmm xa_lo = xmm6;
- Xbyak::Xmm xa_hi = xmm7;
- Xbyak::Xmm xb_lo = xmm8;
- Xbyak::Xmm xb_hi = xmm9;
- Xbyak::Xmm xc_lo = xmm10;
- Xbyak::Xmm xc_hi = xmm11;
- Xbyak::Xmm xd_lo = xmm12;
- Xbyak::Xmm xd_hi = xmm13;
-
- // store xe
- movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo);
- movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi);
-
- mulps(xe_lo, xe_lo);
- mulps(xe_hi, xe_hi);
- addps(xsum_lo, xe_lo);
- addps(xsum_hi, xe_hi);
-
- // xdst <- xsum*xalpha+xk
- movaps(xdst_lo, xsum_lo);
- movaps(xdst_hi, xsum_hi);
- mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]);
- mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]);
- addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]);
- addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]);
-
- movaps(xbase_lo, xdst_lo);
- movaps(xbase_hi, xdst_hi);
- if (pk != prop_kind::forward_inference)
- {
- if (tail != 0) {
- nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi);
- }
- else {
- movups(ptr[scratch], xbase_lo);
- movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
- }
- }
- mulps(xdst_lo, xdst_lo);
- mulps(xdst_hi, xdst_hi);
- mulps(xdst_lo, xbase_lo);
- mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi);
- sqrtps(xdst_lo, xdst_lo);
- sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
- movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
- movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
- divps(xtmp_lo, xdst_lo);
- divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
- movaps(xdst_lo, xtmp_lo);
- movaps(xdst_hi, xtmp_hi);
-
- if (tail != 0) {
- nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi);
- }
- else {
- movups(ptr[dst], xdst_lo);
- movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
- }
-
- movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]);
- movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]);
- mulps(xa_lo, xa_lo);
- mulps(xa_hi, xa_hi);
- subps(xsum_lo, xa_lo);
- subps(xsum_hi, xa_hi);
-
- // xa <- xb
- movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]);
- movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]);
- movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo);
- movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi);
-
- // xb <- xc
- movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
- movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
- movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo);
- movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi);
-
- // xc <- xd
- movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]);
- movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]);
- movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo);
- movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi);
-
- // xd <- xe
- movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]);
- movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]);
- movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo);
- movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi);
-}
-
-template<>
-void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42(
- int tail, int HW, prop_kind_t pk,
- Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
- Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
- Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
- struct nchw_across J,
- float A,
- float K,
- prop_kind_t pk,
- void* code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
- static const uint32_t mask[] = {
- 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
- 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0
- };
- Xbyak::Reg64 c = r10;
- Xbyak::Ymm ymask = ymm2;
- Xbyak::Ymm ye = ymm3;
- Xbyak::Ymm ya = ymm4;
- Xbyak::Ymm yb = ymm5;
- Xbyak::Ymm yc = ymm6;
- Xbyak::Ymm yd = ymm7;
- Xbyak::Ymm ysum = ymm8;
-
- this->preamble();
-
- if (J.tail != 0)
- {
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
- vmovups(ymask, ptr[imm_addr64]);
- }
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- vbroadcastss(yalpha, xalpha);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- vbroadcastss(yk, xk);
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
-
- vxorps(ya, ya, ya);
- vxorps(yb, yb, yb);
- if (J.tail != 0)
- vmaskmovps(yc, ymask, ptr[src + J.HW * 0]);
- else
- vmovups(yc, ptr[src + J.HW * 0]);
- if (J.tail != 0)
- vmaskmovps(yd, ymask, ptr[src + J.HW * 4]);
- else
- vmovups(yd, ptr[src + J.HW * 4]);
-
- vxorps(ysum, ysum, ysum);
- vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
- vfmadd231ps(ysum, yd, yd);
-
- mov(c, J.C - 2);
- Label lrn_loop;
- L(lrn_loop);
-
- if (J.tail != 0)
- vmaskmovps(ye, ymask, ptr[src + J.HW * 8]);
- else
- vmovups(ye, ptr[src + J.HW * 8]);
-
- nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
-
- add(src, J.HW * 4);
- add(dst, J.HW * 4);
- if (pk != prop_kind::forward_inference)
- add(scratch, J.HW * 4);
- dec(c);
- cmp(c, 0);
- jne(lrn_loop, T_NEAR);
-
- vxorps(ye, ye, ye);
-
- nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
- add(src, J.HW * 4);
- add(dst, J.HW * 4);
- if (pk != prop_kind::forward_inference)
- add(scratch, J.HW * 4);
-
- nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
-
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template<>
-jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
- struct nchw_across J,
- float A,
- float K,
- prop_kind_t pk,
- void* code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , alpha(A), k(K)
-{
- static const uint32_t mask[] = {
- 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
- 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0
- };
-
- Xbyak::Reg64 c = r10;
-
- Xbyak::Xmm xmask_lo = xmm2;
- Xbyak::Xmm xmask_hi = xmm3;
- Xbyak::Xmm xsum_lo = xmm4;
- Xbyak::Xmm xsum_hi = xmm5;
- Xbyak::Xmm xa_lo = xmm6;
- Xbyak::Xmm xa_hi = xmm7;
- Xbyak::Xmm xb_lo = xmm8;
- Xbyak::Xmm xb_hi = xmm9;
- Xbyak::Xmm xc_lo = xmm10;
- Xbyak::Xmm xc_hi = xmm11;
- Xbyak::Xmm xd_lo = xmm12;
- Xbyak::Xmm xd_hi = xmm13;
- Xbyak::Xmm xe_lo = xmm14;
- Xbyak::Xmm xe_hi = xmm15;
-
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(dst, ptr[this->param1 + 8]);
- if (pk != prop_kind::forward_inference)
- mov(scratch, ptr[this->param1 + 16]);
-
- sub(rsp, stack_space_needed);
- mov(store_addr, rsp);
- and_(store_addr, -15);
-
- mov(imm_addr64, float2int(this->alpha));
- movq(xalpha, imm_addr64);
- shufps(xalpha, xalpha, 0);
-
- mov(imm_addr64, float2int(this->k));
- movq(xk, imm_addr64);
- shufps(xk, xk, 0);
-
- // put alpha and k into store (free up regs)
- movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha);
- movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk);
-
- if (J.tail != 0)
- {
- mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
- movups(xmask_lo, ptr[imm_addr64]);
- movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
- }
- // init xa, xb
- xorps(xa_lo, xa_lo);
- xorps(xa_hi, xa_hi);
- xorps(xb_lo, xb_lo);
- xorps(xb_hi, xb_hi);
-
- // read xc, xd
- if (J.tail != 0) {
- movups(xc_lo, ptr[src + J.HW * 0]);
- movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
- andps(xc_lo, xmask_lo);
- andps(xc_hi, xmask_hi);
- }
- else {
- movups(xc_lo, ptr[src + J.HW * 0]);
- movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
- }
- if (J.tail != 0) {
- movups(xd_lo, ptr[src + J.HW * 4]);
- movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
- andps(xd_lo, xmask_lo);
- andps(xd_hi, xmask_hi);
- }
- else {
- movups(xd_lo, ptr[src + J.HW * 4]);
- movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
- }
-
- // put xa, xb, xc, xd into store to free-up regs
- movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo);
- movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi);
- movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo);
- movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi);
- movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo);
- movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi);
- movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo);
- movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi);
-
- xorps(xsum_lo, xsum_lo);
- xorps(xsum_hi, xsum_hi);
- mulps(xc_lo, xc_lo);
- mulps(xc_hi, xc_hi);
- addps(xsum_lo, xc_lo);
- addps(xsum_hi, xc_hi);
- mulps(xd_lo, xd_lo);
- mulps(xd_hi, xd_hi);
- addps(xsum_lo, xd_lo);
- addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
-
- mov(c, J.C - 2);
- Label lrn_loop;
- L(lrn_loop);
-
- if (J.tail != 0) {
- movups(xe_lo, ptr[src + J.HW * 8]);
- movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
- andps(xe_lo, xmask_lo);
- andps(xe_hi, xmask_hi);
- }
- else {
- movups(xe_lo, ptr[src + J.HW * 8]);
- movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
- }
-
- nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
- xe_lo, xe_hi,
- xsum_lo, xsum_hi);
-
- add(src, J.HW * 4);
- add(dst, J.HW * 4);
- if (pk != prop_kind::forward_inference)
- add(scratch, J.HW * 4);
- dec(c);
- cmp(c, 0);
- jne(lrn_loop, T_NEAR);
-
- xorps(xe_lo, xe_lo);
- xorps(xe_hi, xe_hi);
-
- nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
- xe_lo, xe_hi,
- xsum_lo, xsum_hi);
- add(src, J.HW * 4);
- add(dst, J.HW * 4);
- if (pk != prop_kind::forward_inference)
- add(scratch, J.HW * 4);
-
- nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
- xe_lo, xe_hi,
- xsum_lo, xsum_hi);
-
- add(rsp, stack_space_needed);
-
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-//////////////////////////////////////////////////////////////////////////////
-// backward kernel
-template <cpu_isa_t isa>
-jit_uni_lrn_bwd_kernel_f32<isa>::jit_uni_lrn_bwd_kernel_f32(
- const struct nchw8c_across &J,
- float A,
- float B,
- int use_h_parallel,
- void *code_ptr,
- size_t code_size)
- : jit_generator(code_ptr, code_size)
- , nalphabeta(-2 * A*B)
- , use_h_parallelizm(use_h_parallel)
-{
- Xbyak::Reg64 t = rsp;
- Xbyak::Reg64 hw = r10;
-
- Xbyak::Xmm xsrc_prev = xmm1;
- Xbyak::Xmm xws_prev = xmm2;
- Xbyak::Xmm xdiffdst_prev = xmm3;
- Xbyak::Ymm ysrc = ymm4;
- Xbyak::Ymm yws = ymm5;
- Xbyak::Ymm ydiffdst = ymm6;
- Xbyak::Xmm xsrc_next = xmm7;
- Xbyak::Xmm xws_next = xmm8;
- Xbyak::Xmm xdiffdst_next = xmm9;
- Xbyak::Ymm ya = ymm10;
- Xbyak::Xmm xa = xmm10;
- Xbyak::Ymm yb = ymm11;
- Xbyak::Ymm yd = ymm12;
- Xbyak::Ymm ye = ymm13;
- Xbyak::Ymm ysum = ymm14;
- Xbyak::Ymm ydiffsrc = ymm15;
-
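-    // diff_src = diff_dst / ws^0.75 - 2*alpha*beta * src * sum_window(diff_dst * src / ws^1.75),
-    // where ws = k + alpha*sum(src^2) is read from the workspace and beta is fixed at 0.75
-    // (hence the cube followed by two square roots below).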
- this->preamble();
-
- mov(src, ptr[this->param1 + 0]);
- mov(diffdst, ptr[this->param1 + 8]);
- mov(workspace, ptr[this->param1 + 16]);
- mov(diffsrc, ptr[this->param1 + 24]);
-
- sub(t, 64);
- mov(imm_addr64, float2int(this->nalphabeta));
- movq(xnalphabeta, imm_addr64);
- vbroadcastss(ynalphabeta, xnalphabeta);
-
- bool is_single = J.version == 3;
- bool is_first = J.version == -1 || J.version == -2;
- bool is_last = J.version == +1 || J.version == -2;
-
- if (is_first || is_single) {
- vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
- vmovups(ptr[t + 0], xsrc_prev);
- }
- if (is_last || is_single) {
- vxorps(xsrc_next, xsrc_next, xsrc_next);
- vmovups(ptr[t + 48], xsrc_next);
- }
- mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W);
- Label lrn_loop;
- L(lrn_loop);
- {
- if (!is_first && !is_single) {
- vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]);
- vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
- vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]);
- vmulps(xa, xws_prev, xws_prev);
- vmulps(xa, xa, xws_prev);
- vsqrtps(xa, xa);
- vsqrtps(xa, xa);
- vmulps(xa, xa, xws_prev);
- vdivps(xsrc_prev, xsrc_prev, xa);
- vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
- }
-
- vmovups(ysrc, ptr[src]);
- vmovups(yws, ptr[workspace]);
- vmovups(ydiffdst, ptr[diffdst]);
- vmulps(ya, yws, yws);
- vmulps(ya, ya, yws);
- vsqrtps(ya, ya);
- vsqrtps(ya, ya);
- vdivps(ydiffsrc, ydiffdst, ya);
- vdivps(ysum, ydiffsrc, yws);
- vmulps(ysum, ysum, ysrc);
-
- if (!is_last && !is_single) {
- vmovups(xws_next, ptr[workspace + J.H*J.W * 32]);
- vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
- vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]);
- vmulps(xa, xws_next, xws_next);
- vmulps(xa, xa, xws_next);
- vsqrtps(xa, xa);
- vsqrtps(xa, xa);
- vmulps(xa, xa, xws_next);
- vdivps(xsrc_next, xsrc_next, xa);
- vdivps(xsrc_next, xsrc_next, xws_next);
- vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
- }
-
- if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev);
- vmovups(ptr[t + 16], ysum);
- if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next);
-
- vmovups(ya, ptr[t + 16 - 8]);
- vmovups(yb, ptr[t + 16 - 4]);
- vaddps(ysum, ysum, ya);
- vmulps(ysrc, ysrc, ynalphabeta);
- vaddps(ysum, ysum, yb);
-
- vmovups(yd, ptr[t + 16 + 4]);
- vmovups(ye, ptr[t + 16 + 8]);
- vaddps(ysum, ysum, yd);
- vaddps(ysum, ysum, ye);
-
- vfmadd231ps(ydiffsrc, ysum, ysrc);
-
- vmovups(ptr[diffsrc], ydiffsrc);
-
- add(src, 32);
- add(diffsrc, 32);
- add(diffdst, 32);
- add(workspace, 32);
-
- dec(hw);
- cmp(hw, 0);
- jne(lrn_loop, T_NEAR);
- }
-
- add(t, 64);
- this->postamble();
-
- ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
- this->getCode()));
-}
-
-template struct jit_uni_lrn_fwd_kernel_f32<sse42>;
-template struct jit_uni_lrn_fwd_kernel_f32<avx2>;
-template struct jit_uni_lrn_bwd_kernel_f32<avx2>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp
deleted file mode 100644
index 2b3ed43cd4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp
+++ /dev/null
@@ -1,183 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP
-#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-
-#include "jit_generator.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 };
-
-typedef struct {
- const float *src;
- float *dst, *scratch;
-} jit_args_fwd_t;
-
-typedef struct {
- const float *src, *diff_dst, *scratch;
- float *diff_src;
-} jit_args_bwd_t;
-
-struct nchw8c_across {
-    /* version:
-     * -1: channels 0..7,
-     *  1: channels C-8 .. C-1,
-     *  0: other channels,
-     *  3: channels handled by this kernel only (no previous or next block)
-     */
- int H, W, version;
- nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {}
-};
-
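-/* Parameters for the within-channel LRN kernel on the nChw8c layout: spatial extent
- * H x W and the side length 'size' of the square local region. */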
-struct nchw8c_within {
- int H, W, size;
- nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {}
-};
-
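-/* Parameters for the across-channel LRN kernel on plain nchw: C channels, HW = H*W elements
- * per channel (the channel stride), and 'tail' leftover spatial positions (< 8) handled
- * under a mask. */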
-struct nchw_across {
- int C, HW, tail;
- nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {}
-};
-
-struct nhwc_across {
- int C;
- nhwc_across(int c) : C(c) {}
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator {
- Xbyak::Reg64 src = rax;
- Xbyak::Reg64 dst = r8;
- Xbyak::Reg64 scratch = rdx;
- Xbyak::Reg64 imm_addr64 = rbx;
- Xbyak::Reg64 store_addr = rbp;
-
- Xbyak::Xmm xalpha = xmm0;
- Xbyak::Ymm yalpha = ymm0;
- Xbyak::Xmm xk = xmm1;
- Xbyak::Ymm yk = ymm1;
-
- float alpha;
- float k;
-
- int stack_space_needed = 11 * 4 * sizeof(float) + 16;
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32)
-
- /* cpu specific part */
- using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type;
-
- jit_uni_lrn_fwd_kernel_f32(
- const struct nchw8c_within &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr = nullptr,
- size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);
- jit_uni_lrn_fwd_kernel_f32(
- const struct nchw8c_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr = nullptr,
- size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
- jit_uni_lrn_fwd_kernel_f32(
- const struct nhwc_across &J,
- float A,
- float K,
- prop_kind_t pk,
- void *code_ptr = nullptr,
- size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
- jit_uni_lrn_fwd_kernel_f32(
- struct nchw_across J,
- float A,
- float K,
- prop_kind_t pk,
- void* code_ptr = nullptr,
- size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE);
-
- void within_body(
- int hoff, int Hoff, int woff, int Woff, int stride,
- Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
- prop_kind_t pk);
- void within_body_sse42(
- int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk);
-
-
- void nchw_body(int tail, int HW, prop_kind_t pk,
- Xbyak::Ymm ymask,
- Xbyak::Ymm ya,
- Xbyak::Ymm yb,
- Xbyak::Ymm yc,
- Xbyak::Ymm yd,
- Xbyak::Ymm ye,
- Xbyak::Ymm ysum);
- void nchw_body_sse42(int tail, int HW, prop_kind_t pk,
- Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
- Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
- Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
- void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst,
- Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi);
-
- void operator()(jit_args_fwd_t *arg) { ker(arg); }
- void(*ker)(jit_args_fwd_t *);
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator {
- Xbyak::Reg64 src = rax;
- Xbyak::Reg64 diffsrc = r8;
- Xbyak::Reg64 diffdst = r9;
- Xbyak::Reg64 workspace = rdx;
- Xbyak::Reg64 imm_addr64 = rsi;
-
- Xbyak::Xmm xnalphabeta = xmm0;
- Xbyak::Ymm ynalphabeta = ymm0;
-
- float nalphabeta;
-
- int use_h_parallelizm;
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32)
-
- jit_uni_lrn_bwd_kernel_f32(
- const struct nchw8c_across &J,
- float A,
- float B,
- int use_h_parallel,
- void *code_ptr = nullptr,
- size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
-
- void operator()(jit_args_bwd_t *arg) { ker(arg); }
- void(*ker)(jit_args_bwd_t *);
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp
deleted file mode 100644
index bf8e609d23..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp
+++ /dev/null
@@ -1,699 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-* Copyright 2018 YANDEX LLC
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "utils.hpp"
-#include "cpu_pooling_pd.hpp"
-
-#include "jit_uni_pool_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-using namespace alg_kind;
-
-#define GET_OFF(field) offsetof(jit_pool_call_s, field)
-
-template <cpu_isa_t isa>
-status_t jit_uni_pool_kernel_f32<isa>::init_conf(jit_pool_conf_t &jpp,
- const pooling_pd_t *ppd) {
- const auto &pd = *ppd->desc();
- const memory_desc_wrapper src_d(
- ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md());
- const memory_desc_wrapper dst_d(
- ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md());
-
- bool args_ok = true
- && mayiuse(isa)
- && utils::one_of(pd.alg_kind, pooling_max,
- pooling_avg_include_padding,
- pooling_avg_exclude_padding);
- if (!args_ok) return status::unimplemented;
-
- const int simd_w = isa == avx512_common ? 16 : 8;
- const int ndims = src_d.ndims();
-
- jpp.ndims = ndims;
- jpp.mb = src_d.dims()[0];
-
- jpp.c = utils::rnd_up(src_d.dims()[1], simd_w);
- if (jpp.c > src_d.padded_dims()[1])
- return status::unimplemented;
-
- jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
- jpp.ih = src_d.dims()[ndims-2];
- jpp.iw = src_d.dims()[ndims-1];
- jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
- jpp.oh = dst_d.dims()[ndims-2];
- jpp.ow = dst_d.dims()[ndims-1];
-
- jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1;
- jpp.stride_h = pd.strides[ndims-4];
- jpp.stride_w = pd.strides[ndims-3];
- jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
- jpp.kh = pd.kernel[ndims-4];
- jpp.kw = pd.kernel[ndims-3];
-
- jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0;
- jpp.t_pad = pd.padding[0][ndims-4];
- jpp.l_pad = pd.padding[0][ndims-3];
-
- jpp.alg = pd.alg_kind;
-
- jpp.is_training = pd.prop_kind == prop_kind::forward_training;
- jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
- jpp.ind_dt = ppd->workspace_md()
- ? ppd->workspace_md()->data_type : data_type::undef;
-
- jpp.simple_alg = jpp.is_training
- || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);
-
- jpp.c_block = simd_w;
-
- jpp.nb_c = jpp.c / jpp.c_block;
- if (jpp.alg == pooling_max) {
- jpp.ur_w = isa == avx512_common ? 16 : 4;
- if (jpp.is_training)
- jpp.ur_w = isa == avx512_common ? 9 : 3;
- else if (jpp.is_backward)
- jpp.ur_w = isa == avx512_common ? 6 : 3;
- } else {
- if (jpp.is_backward)
- jpp.ur_w = isa == avx512_common ? 12 : 6;
- else
- jpp.ur_w = isa == avx512_common ? 24 : 12;
- }
- if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow;
- if (jpp.l_pad > jpp.ur_w) return status::unimplemented;
-
- jpp.ur_w_tail = jpp.ow % jpp.ur_w;
-
- return status::success;
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel_f32<isa>::maybe_recalculate_divisor(int jj,
- int ur_w, int pad_l, int pad_r) {
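-    // For average pooling that excludes padding, the divisor is the number of kernel columns that
-    // actually overlap the input at this output column, scaled by the per-row factor supplied in
-    // ker_area_h; it is rebroadcast only when the overlap width changes.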
- if (jpp.alg == pooling_avg_exclude_padding) {
- int kw = jpp.kw;
- int stride_w = jpp.stride_w;
-
- int non_zero_kw = kw;
- non_zero_kw -= nstl::max(0, pad_l - jj*stride_w);
- non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w);
-
- if (non_zero_kw != prev_kw) {
- mov(tmp_gpr, float2int((float)non_zero_kw));
- movq(xmm_tmp, tmp_gpr);
- uni_vbroadcastss(vmm_tmp, xmm_tmp);
- uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
- prev_kw = non_zero_kw;
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel_f32<isa>::avg_step(int ur_w, int pad_l,
- int pad_r) {
-
- int iw = jpp.iw;
- int kw = jpp.kw;
- int stride_w = jpp.stride_w;
- int c_block = jpp.c_block;
- Label kd_label, kh_label;
-
- for (int jj = 0; jj < ur_w; jj++) {
- if (jpp.is_backward) {
- uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);
- maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
- uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
- } else {
- uni_vpxor(vreg(jj), vreg(jj), vreg(jj));
- }
- }
-
- if (jpp.simple_alg && jpp.ndims == 5) {
- push(reg_input);
- push(reg_output);
- mov(aux_reg_input_d, reg_input);
- mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
- L(kd_label);
- mov(aux_reg_input, aux_reg_input_d);
- } else {
- mov(aux_reg_input, reg_input);
- }
-
- xor_(kj, kj);
- L(kh_label);
- {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, pad_l - ki);
- int jj_end = ur_w
- - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
- for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
- if (aux_input_offset > iw * c_block)
- continue;
- int input_offset = sizeof(float)*aux_input_offset;
- if (jpp.is_backward) {
- uni_vmovups(vreg(ur_w+jj),
- ptr[aux_reg_input + input_offset]);
- uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj));
- uni_vmovups(vmmword[aux_reg_input + input_offset],
- vreg(ur_w+jj));
- } else {
- uni_vaddps(vreg(jj), vreg(jj),
- ptr[aux_reg_input + input_offset]);
- }
- }
- }
- add(aux_reg_input, sizeof(float) * iw * c_block);
- inc(kj);
- cmp(kj, reg_kh);
- jl(kh_label, T_NEAR);
- }
-
- if (jpp.simple_alg && jpp.ndims == 5)
- {
- add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- pop(reg_output);
- pop(reg_input);
- }
-
- if (!jpp.is_backward) {
- for (int jj = 0; jj < ur_w; jj++) {
- maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
- uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
- uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block],
- vreg(jj));
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel_f32<isa>::max_step_fwd(int ur_w, int pad_l,
- int pad_r) {
- int iw = jpp.iw;
- int kw = jpp.kw;
- int stride_w = jpp.stride_w;
- int c_block = jpp.c_block;
- Label kd_label, kh_label;
-
- mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest()));
- movq(xmm_tmp, tmp_gpr);
- uni_vbroadcastss(vmm_tmp, xmm_tmp);
-
- for (int jj = 0; jj < ur_w; jj++) {
- uni_vmovups(vreg(jj), vmm_tmp);
- if (jpp.is_training)
- uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj));
- }
- if (jpp.is_training)
- {
- movq(xmm_tmp, reg_k_shift);
- uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
- }
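-    // vmm_k_offset carries the running kernel-element index; the same mask that selects a new
-    // maximum below also blends this index into the index registers, producing the argmax workspace.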
-
- if (jpp.ndims == 5) {
- push(reg_input);
- push(reg_output);
- mov(aux_reg_input_d, reg_input);
- mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
- L(kd_label);
- mov(aux_reg_input, aux_reg_input_d);
- } else {
- mov(aux_reg_input, reg_input);
- }
- xor_(kj, kj);
- L(kh_label);
- {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, pad_l - ki);
- int jj_end = ur_w
- - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
- for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
- if (aux_input_offset > iw * c_block)
- continue;
- int input_offset = sizeof(float)*aux_input_offset;
- uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]);
- if (isa == sse42) {
- movups(vmm_mask, vreg(jj));
- cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os);
- blendvps(vreg(jj), vreg(ur_w+jj));
- if (jpp.is_training)
- blendvps(vreg(2*ur_w+jj), vmm_k_offset);
- } else if (isa == avx) {
- vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj),
- _cmp_lt_os);
- vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj),
- vreg(3*ur_w+jj));
- if (jpp.is_training)
- vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj),
- vmm_k_offset, vreg(3*ur_w+jj));
- } else {
- vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os);
- vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj));
- if (jpp.is_training)
- vblendmps(vreg(2*ur_w+jj) | k_store_mask,
- vreg(2*ur_w+jj), vmm_k_offset);
- }
- }
- if (jpp.is_training) {
- if (isa == avx && !mayiuse(avx2)) {
- avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
- } else {
- uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
- }
- }
- }
- add(aux_reg_input, sizeof(float) * iw * c_block);
- inc(kj);
- cmp(kj, reg_kh);
- jl(kh_label, T_NEAR);
- }
-
- if (jpp.ndims == 5)
- {
- add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
- if (jpp.is_training) {
- mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
- movq(xmm_tmp, tmp_gpr);
- uni_vpbroadcastd(vmm_tmp, xmm_tmp);
- if (isa == avx && !mayiuse(avx2)) {
- Xmm t(vmm_mask.getIdx());
- avx_vpadd1(vmm_k_offset, xmm_tmp, t);
- } else {
- uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
- }
- }
-
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- pop(reg_output);
- pop(reg_input);
- }
-
- for (int jj = 0; jj < ur_w; jj++) {
- uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj));
- if (jpp.is_training) {
- const size_t step_index
- = jj * c_block * types::data_type_size(jpp.ind_dt);
-
- auto x = xreg(2 * ur_w + jj);
- if (jpp.ind_dt == data_type::u8) {
- if (isa == sse42) {
- for (int i = 0; i < 4; ++i)
- pextrb(ptr[reg_index + step_index + i], x, 4*i);
- } else if (isa == avx) {
- auto y = yreg(2 * ur_w + jj);
- if (jj == 0) {
- movd(xmm_tmp, reg_shuf_mask);
- uni_vpbroadcastd(vmm_tmp, xmm_tmp);
- }
- if (mayiuse(avx2)) {
- vpshufb(y, y, vmm_tmp);
- movd(ptr[reg_index + step_index], x);
- vperm2i128(y, y, y, 0x1u);
- movd(ptr[reg_index + step_index + 4], x);
- } else {
- Xmm t(vmm_mask.getIdx());
- vextractf128(t, y, 0);
- vpshufb(t, t, xmm_tmp);
- movd(ptr[reg_index + step_index], t);
- vextractf128(t, y, 1);
-                        vpshufb(t, t, xmm_tmp); // the broadcast mask's upper and lower 128-bit halves are identical, so xmm_tmp serves both lanes
- movd(ptr[reg_index + step_index + 4], t);
- }
- } else {
- auto v = vreg(2 * ur_w + jj);
- vpmovusdb(x, v);
- vmovups(ptr[reg_index + step_index], v | k_index_mask);
- }
- } else {
- uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj));
- }
- }
- }
-}
-
-template <cpu_isa_t isa>
-inline void jit_uni_pool_kernel_f32<isa>::max_step_bwd(int ur_w, int pad_l,
- int pad_r) {
-
- int iw = jpp.iw;
- int kw = jpp.kw;
- int stride_w = jpp.stride_w;
- int c_block = jpp.c_block;
- Label kd_label, kh_label;
-
- for (int jj = 0; jj < ur_w; jj++) {
- uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);
-
- const size_t step_index
- = jj * c_block * types::data_type_size(jpp.ind_dt);
- if (jpp.ind_dt == data_type::u8) {
- if (isa == sse42) {
- movd(xreg(ur_w+jj), ptr[reg_index + step_index]);
- pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
- } else if (isa == avx) {
- movq(xreg(ur_w+jj), ptr[reg_index + step_index]);
- if (!mayiuse(avx2)) {
- avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp);
- } else {
- vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
- }
- } else {
- vmovups(vreg(ur_w+jj) | k_index_mask,
- ptr[reg_index + step_index]);
- vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
- }
- } else {
- uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]);
- }
- }
- movq(xmm_tmp, reg_k_shift);
- uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
-
- if (jpp.simple_alg && jpp.ndims == 5) {
- push(reg_input);
- push(reg_output);
- if (isa == sse42) {
- // Save rdi since it is used in maskmovdqu
- assert(dst_ptr == rdi);
- push(dst_ptr);
- }
- mov(aux_reg_input_d, reg_input);
- mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
- mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
- L(kd_label);
- mov(aux_reg_input, aux_reg_input_d);
- } else {
- mov(aux_reg_input, reg_input);
- }
-
- xor_(kj, kj);
- L(kh_label);
- {
- for (int ki = 0; ki < kw; ki++) {
- int jj_start = nstl::max(0, pad_l - ki);
- int jj_end = ur_w
- - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
- for (int jj = jj_start; jj < jj_end; jj++) {
- int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
- if (aux_input_offset > iw * c_block)
- continue;
- int input_offset = sizeof(float)*aux_input_offset;
- uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]);
- if (isa == sse42) {
- mov(dst_ptr, aux_reg_input);
- add(dst_ptr, input_offset);
-
- movups(vreg(3*ur_w+jj), vreg(ur_w+jj));
- pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset);
- addps(vreg(2*ur_w+jj), vreg(jj));
- maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj));
- } else if (isa == avx) {
- if (mayiuse(avx2)) {
- vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset);
- } else {
- avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp);
- }
- vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj));
- vmaskmovps(vmmword[aux_reg_input + input_offset],
- vreg(3*ur_w+jj), vreg(2*ur_w+jj));
- } else {
- vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset);
- vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj));
- vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp);
- vmovups(vmmword[aux_reg_input +
- sizeof(float)*aux_input_offset], vreg(2*ur_w+jj));
- }
- }
- if (isa == avx && !mayiuse(avx2)) {
- avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
- } else {
- uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
- }
- }
- add(aux_reg_input, sizeof(float) * iw * c_block);
- inc(kj);
- cmp(kj, reg_kh);
- jl(kh_label, T_NEAR);
- }
- if (jpp.simple_alg && jpp.ndims == 5)
- {
- add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
-
- mov(tmp_gpr, reg_kd_pad_shift);
- movq(xmm_tmp, tmp_gpr);
- uni_vpbroadcastd(vmm_tmp, xmm_tmp);
- if (isa == avx && !mayiuse(avx2)) {
- Xmm t(vmm_mask.getIdx());
- avx_vpadd1(vmm_k_offset, vmm_tmp, t);
- } else {
- uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
- }
-
- dec(ki);
- cmp(ki, 0);
- jg(kd_label, T_NEAR);
- if (isa == sse42) {
- // Save rdi since it is used in maskmovdqu
- assert(dst_ptr == rdi);
- pop(dst_ptr);
- }
- pop(reg_output);
- pop(reg_input);
- }
-}
-
-template <cpu_isa_t isa>
-void jit_uni_pool_kernel_f32<isa>::maybe_zero_diff_src() {
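-    // Backward pooling accumulates into diff_src, so the destination block is zeroed first when
-    // the driver requests it through the 'oh' argument (zero means skip).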
- assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0);
- Label l_skip, l_zero;
-
- auto reg_oh = tmp_gpr;
- mov(reg_oh, ptr[reg_param + GET_OFF(oh)]);
- cmp(reg_oh, 0);
- jz(l_skip, T_NEAR);
-
- if (jpp.ndims == 5) {
- mov(zero_size, ptr[reg_param + GET_OFF(oh)]);
- mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float));
- imul(zero_size, tmp_gpr);
- }
-
- auto vzero = vmm_tmp;
- uni_vpxor(vzero, vzero, vzero);
-
- auto reg_off = tmp_gpr;
- xor_(reg_off, reg_off);
-
- L(l_zero);
- {
- const int dim = jpp.iw * jpp.c_block * sizeof(float);
- for (int i = 0; i < dim; i += cpu_isa_traits<isa>::vlen)
- uni_vmovups(ptr[reg_input + reg_off + i], vzero);
- add(reg_off, dim);
- if (jpp.ndims == 5) cmp(reg_off, zero_size);
- else cmp(reg_off, jpp.ih * dim);
- jl(l_zero, T_NEAR);
- }
-
- L(l_skip);
-}
-
-template <cpu_isa_t isa>
-void jit_uni_pool_kernel_f32<isa>::generate() {
-
- this->preamble();
-
- int ow = jpp.ow;
- int iw = jpp.iw;
- int kw = jpp.kw;
- int kh = jpp.kh;
- int ur_w = jpp.ur_w;
- int c_block = jpp.c_block;
- int stride_w = jpp.stride_w;
- int l_pad = jpp.l_pad;
- int ur_w_tail = jpp.ur_w_tail;
-
- int n_oi = ow / ur_w;
-
- prev_kw = 0;
-
- int vlen = cpu_isa_traits<isa>::vlen;
-
-#if defined(_WIN32)
- // Always mimic the Unix ABI (see the note about maskmovdqu in the header
- // file).
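-    // (The three XORs below swap rdi and rcx, the first-argument registers of the Win64 and System V ABIs.)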
- xor_(rdi, rcx);
- xor_(rcx, rdi);
- xor_(rdi, rcx);
-#endif
-
- mov(reg_input, ptr[reg_param + GET_OFF(src)]);
- mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
- mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
- mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
- mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
- mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
-
- if (jpp.is_backward)
- maybe_zero_diff_src();
-
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
- mov(tmp_gpr, 1);
- movq(xmm_one, tmp_gpr);
- uni_vpbroadcastd(vmm_one, xmm_one);
-
- if (isa == avx) {
- mov(reg_shuf_mask, 0x0c080400);
- } else if (isa >= avx512_common) {
- mov(tmp_gpr.cvt32(), 0x000f);
- kmovw(k_index_mask, tmp_gpr.cvt32());
- }
- }
-
- int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1));
- int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
- if (r_pad1 > 0) n_oi--;
-
- if (jpp.alg == pooling_avg_exclude_padding) {
- movq(xmm_ker_area_h, reg_ker_area_h);
- uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h);
- }
-
- if (jpp.alg == pooling_avg_include_padding) {
- mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd)));
- movq(xmm_tmp, tmp_gpr);
- uni_vpbroadcastd(vmm_tmp, xmm_tmp);
- }
- if (l_pad > 0) {
- n_oi--;
- if (n_oi < 0 && r_pad1 > 0) {
- step(ur_w, l_pad, r_pad1);
- } else {
- step(ur_w, l_pad, 0);
- }
-
- if (isa == sse42) {
- if (n_oi < 0 && r_pad1 > 0) {
- step_high_half(ur_w, l_pad, r_pad1);
- } else {
- step_high_half(ur_w, l_pad, 0);
- }
- }
-
- if (isa == sse42) {
- add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen);
- add(reg_output, sizeof(float)*ur_w*c_block - vlen);
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
- add(reg_index, (2 * ur_w - 1) * c_block / 2
- * types::data_type_size(jpp.ind_dt));
- } else {
- add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block);
- add(reg_output, sizeof(float)*ur_w*c_block);
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
- add(reg_index, ur_w * c_block
- * types::data_type_size(jpp.ind_dt));
- }
- }
-
- xor_(oi_iter, oi_iter);
- if (n_oi > 0) {
- Label ow_loop;
- L(ow_loop); {
- step(ur_w, 0, 0);
-
- if (isa == sse42) {
- step_high_half(ur_w, 0, 0);
- }
-
- if (isa == sse42) {
- add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
- add(reg_output, sizeof(float)*ur_w*c_block - vlen);
- if (jpp.alg == pooling_max &&
- (jpp.is_training || jpp.is_backward))
- add(reg_index, (2 * ur_w - 1) * c_block / 2
- * types::data_type_size(jpp.ind_dt));
- } else {
- add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
- add(reg_output, sizeof(float)*ur_w*c_block);
- if (jpp.alg == pooling_max &&
- (jpp.is_training || jpp.is_backward))
- add(reg_index, ur_w * c_block
- * types::data_type_size(jpp.ind_dt));
- }
-
- inc(oi_iter);
- cmp(oi_iter, n_oi);
- jl(ow_loop, T_NEAR);
- }
- }
-
- if (r_pad1 > 0 && n_oi >= 0) {
- step(ur_w, 0, r_pad1);
-
- if (isa == sse42) {
- step_high_half(ur_w, 0, r_pad1);
- }
-
- if (isa == sse42) {
- add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
- add(reg_output, sizeof(float)*ur_w*c_block - vlen);
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
- add(reg_index, (2 * ur_w - 1) * c_block / 2
- * types::data_type_size(jpp.ind_dt));
- } else {
- add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
- add(reg_output, sizeof(float)*ur_w*c_block);
- if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
- add(reg_index, ur_w * c_block
- * types::data_type_size(jpp.ind_dt));
- }
- }
-
- if (ur_w_tail != 0) {
- step(ur_w_tail, 0, r_pad);
-
- if (isa == sse42) {
- step_high_half(ur_w_tail, 0, r_pad);
- }
- }
-
- this->postamble();
-}
-
-template struct jit_uni_pool_kernel_f32<sse42>;
-template struct jit_uni_pool_kernel_f32<avx>; // implements both <avx> and <avx2>
-template struct jit_uni_pool_kernel_f32<avx512_common>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
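
For reference, the block scheduling emitted by the generate() body removed above boils down to plain integer arithmetic: the output width ow is tiled into n_oi fully unrolled blocks of ur_w outputs plus a ur_w_tail remainder, with the first and last blocks absorbing the left and right padding. The following stand-alone C++ sketch (illustrative only; pointer advancement, index bookkeeping, and the sse42 half-vector pass are omitted) reproduces that arithmetic and prints which step() calls the kernel would emit:

    #include <algorithm>
    #include <cstdio>

    // Mirrors the tiling arithmetic of generate(); variable names follow the
    // deleted source (ow/iw: output/input width, kw: kernel width, etc.).
    static void print_step_schedule(int ow, int iw, int kw, int stride_w,
                                    int l_pad, int ur_w, int ur_w_tail) {
        int n_oi = ow / ur_w;
        int r_pad  = std::max(0, (ow - 1) * stride_w + kw - 1 - (iw + l_pad - 1));
        int r_pad1 = (ur_w * n_oi - 1) * stride_w + kw - 1 - (iw + l_pad - 1);
        if (r_pad1 > 0) n_oi--;                 // last full block is right-padded

        if (l_pad > 0) {
            n_oi--;                             // first full block is left-padded
            std::printf("step(%d, l_pad=%d, r_pad=%d)\n", ur_w, l_pad,
                        (n_oi < 0 && r_pad1 > 0) ? r_pad1 : 0);
        }
        for (int oi = 0; oi < n_oi; ++oi)       // unpadded middle blocks
            std::printf("step(%d, 0, 0)\n", ur_w);
        if (r_pad1 > 0 && n_oi >= 0)            // right-padded block
            std::printf("step(%d, 0, r_pad=%d)\n", ur_w, r_pad1);
        if (ur_w_tail != 0)                     // remainder of ow
            std::printf("step(%d, 0, r_pad=%d)\n", ur_w_tail, r_pad);
    }
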
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp
deleted file mode 100644
index 992b526587..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp
+++ /dev/null
@@ -1,192 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-* Copyright 2018 YANDEX LLC
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_UNI_POOL_KERNEL_F32_HPP
-#define JIT_UNI_POOL_KERNEL_F32_HPP
-
-#include <cfloat>
-
-#include "c_types_map.hpp"
-#include "pooling_pd.hpp"
-#include "type_helpers.hpp"
-
-#include "jit_generator.hpp"
-#include "jit_primitive_conf.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace Xbyak;
-
-template <cpu_isa_t isa>
-struct jit_uni_pool_kernel_f32: public jit_generator {
- jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp)
- {
- this->generate();
- jit_ker = (decltype(jit_ker))this->getCode();
- }
-
- jit_pool_conf_t jpp;
-
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32)
-
- void operator()(jit_pool_call_s *arg) { jit_ker(arg); }
- static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd);
-
-private:
- using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx,
- Ymm, Zmm>::type;
- Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); }
- Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); }
- Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); }
-
- const AddressFrame &vmmword = (isa == sse42) ? xword :
- (isa == avx) ? yword : zword;
-
- Xmm vmm_mask = Xmm(0);
- Xmm xmm_ker_area_h = Xmm(2);
- Xmm xmm_one = Xmm(2);
- Xmm xmm_tmp = Xmm(3);
-
- Vmm vmm_ker_area_h = Vmm(2);
- Vmm vmm_one = Vmm(2);
- Vmm vmm_tmp = Vmm(3);
-
- Vmm vmm_k_offset = Vmm(1);
-
- Opmask k_index_mask = Opmask(6);
- Opmask k_store_mask = Opmask(7);
-
- // Here be some (tame) dragons. This kernel does not follow the regular
-    // OS-agnostic ABI pattern because when isa is sse42 it uses the maskmovdqu
-    // instruction, which has its destination hardcoded in rdi. Therefore:
- // - all registers are hardcoded
- // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
- //
- // While this is only required by the backward pass, the quirk above
- // is applied to the forward pass as well to keep things simpler.
-
- using reg64_t = const Xbyak::Reg64;
- reg64_t reg_param = rdi; // Always mimic the Unix ABI
- reg64_t reg_input = r8;
- reg64_t aux_reg_input = r9;
- reg64_t reg_index = r10;
- reg64_t reg_output = r12;
- reg64_t reg_kd_pad_shift = r13;
- reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu
-
- reg64_t kj = r14;
- reg64_t oi_iter = r15;
- reg64_t reg_kh = rax;
- reg64_t reg_k_shift = rbx;
- reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
- reg64_t reg_ker_area_h = rdx;
-
- reg64_t zero_size = r15;
- reg64_t ki = r12;
- reg64_t aux_reg_input_d = r8;
-
- Xbyak::Reg32 reg_shuf_mask = esi;
-
- int prev_kw;
- void (*jit_ker)(jit_pool_call_s *);
-
- void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r);
- void avg_step(int ur_w, int pad_l, int pad_r);
- void max_step_fwd(int ur_w, int pad_l, int pad_r);
- void max_step_bwd(int ur_w, int pad_l, int pad_r);
-
- void maybe_zero_diff_src();
-
- void step(int ur_w, int pad_l, int pad_r) {
- if (jpp.alg == alg_kind::pooling_max) {
- if(jpp.is_backward)
- max_step_bwd(ur_w, pad_l, pad_r);
- else
- max_step_fwd(ur_w, pad_l, pad_r);
- }
- else
- avg_step(ur_w, pad_l, pad_r);
- }
-
- void step_high_half(int ur_w, int pad_l, int pad_r) {
- add(reg_input, sizeof(float) * 4);
- add(reg_output, sizeof(float) * 4);
- if (jpp.alg == alg_kind::pooling_max &&
- (jpp.is_training || jpp.is_backward))
- add(reg_index, types::data_type_size(jpp.ind_dt) * 4);
-
- step(ur_w, pad_l, pad_r);
- }
-
- void generate();
-
- void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
- assert(y0.getIdx() != x1.getIdx());
- vextractf128(xtmp, y0, 0);
- vpaddd(xtmp, xtmp, x1);
- vinsertf128(y0, y0, xtmp, 0);
- vextractf128(xtmp, y0, 1);
- vpaddd(xtmp, xtmp, x1);
- vinsertf128(y0, y0, xtmp, 1);
- }
-
- void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) {
- assert(false /*function should not be used*/);
- paddd(x0, x1);
- }
-
- void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
- Xmm x0(y0.getIdx());
- pshufd(xmm_tmp, x1, 1);
- pmovzxbd(x0, x1);
- pmovzxbd(xmm_tmp, xmm_tmp);
- vinsertf128(y0, y0, xmm_tmp, 1);
- }
-
- void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) {
- assert(false /*function should not be used*/);
- pmovzxbd(x0, x1);
- }
-
- void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) {
- assert(y0.getIdx() != y1.getIdx());
- assert(y0.getIdx() != y2.getIdx());
- Xmm x0(y0.getIdx());
- Xmm x2(y2.getIdx());
- vextractf128(x0, y1, 1);
- vextractf128(xtmp, y2, 1);
- pcmpeqd(xtmp, x0);
- vextractf128(x0, y1, 0);
- pcmpeqd(x0, x2);
- vinsertf128(y0, y0, xtmp, 1);
- }
-
- void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) {
- assert(false /*function should not be used*/);
- pcmpeqd(x0, x1);
- }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
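
The "tame dragons" note above is easiest to see in isolation: under the Windows x64 ABI the first argument arrives in rcx, but maskmovdqu (used by the sse42 backward path) implicitly writes through rdi, so the prologue exchanges the two registers with three XORs and the rest of the kernel can always treat rdi as the parameter register. A minimal C++ equivalent of that exchange (hypothetical helper, not library code):

    #include <cassert>
    #include <cstdint>

    // XOR swap, the same trick as the three xor_() instructions emitted for
    // _WIN32 in generate(): afterwards the contents of the two "registers"
    // are exchanged without using a temporary.
    static void xor_swap(std::uint64_t &rdi, std::uint64_t &rcx) {
        rdi ^= rcx;
        rcx ^= rdi;
        rdi ^= rcx;
    }

    int main() {
        std::uint64_t rdi = 0x1111, rcx = 0x2222;  // rcx holds the argument on Windows
        xor_swap(rdi, rcx);
        assert(rdi == 0x2222 && rcx == 0x1111);    // argument now visible in rdi
        return 0;
    }
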
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp
deleted file mode 100644
index afbcf996d8..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp
+++ /dev/null
@@ -1,264 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_types.h"
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "nstl.hpp"
-
-#include "jit_uni_pooling.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-void jit_uni_pooling_fwd_t<isa>::execute_forward(const data_t *src,
- data_t *dst, char *indices) const {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper indices_d(pd()->workspace_md());
- const size_t ind_dt_size = indices
- ? types::data_type_size(indices_d.data_type()) : 0;
-
- const auto &jpp = pd()->jpp_;
-
- auto ker = [&](int n, int b_c, int oh) {
- auto arg = jit_pool_call_s();
-
- const int ij = oh * jpp.stride_h;
- const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
- const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
- const int ih = nstl::max(ij - jpp.t_pad, 0);
-
- arg.src = &src[src_d.blk_off(n, b_c, ih)];
- arg.dst = &dst[dst_d.blk_off(n, b_c, oh)];
- if (indices) {
- const size_t ind_off = indices_d.blk_off(n, b_c, oh);
- arg.indices = &indices[ind_off * ind_dt_size];
- }
- arg.oh = oh == 0;
- arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
- arg.kh_padding_shift = i_t_overflow*jpp.kw;
- arg.kw_padding = 0;
- arg.ker_area_h = (float)(jpp.kh -
- nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
- nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
- (*kernel_)(&arg);
- };
-
- parallel_nd(jpp.mb, jpp.nb_c, jpp.oh,
- [&](int n, int b_c, int oh) {
- ker(n, b_c, oh);
- });
-}
-
-template <cpu_isa_t isa>
-void jit_uni_pooling_fwd_t<isa>::execute_forward_3d(const data_t *src,
- data_t *dst, char *indices) const {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper indices_d(pd()->workspace_md());
- const size_t ind_dt_size = indices
- ? types::data_type_size(indices_d.data_type()) : 0;
-
- const auto &jpp = pd()->jpp_;
-
- auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
- int d_b_overflow) {
- auto arg = jit_pool_call_s();
-
- const int ij = oh * jpp.stride_h;
- const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
- const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
- const int ih = nstl::max(ij - jpp.t_pad, 0);
-
- arg.src = &src[src_d.blk_off(n, b_c, id, ih)];
- arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)];
- if (indices) {
- const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
- arg.indices = &indices[ind_off * ind_dt_size];
- }
- arg.oh = (oh + od == 0);
- arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
- arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
- arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh;
- arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
- arg.kw_padding = 0;
- arg.ker_area_h = (float)(jpp.kh -
- nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
- nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
- nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
- nstl::max(0, jpp.f_pad - od*jpp.stride_d));
-
-
- (*kernel_)(&arg);
- };
-
- parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
- [&](int n, int b_c, int od) {
- const int ik = od * jpp.stride_d;
- const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
- const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad)
- -jpp.id;
- const int id = nstl::max(ik - jpp.f_pad, 0);
- for (int oh = 0; oh < jpp.oh; ++oh) {
- ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow);
- }
- });
-}
-
-template <cpu_isa_t isa>
-void jit_uni_pooling_bwd_t<isa>::execute_backward(const data_t *diff_dst,
- const char *indices, data_t *diff_src) const {
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper indices_d(pd()->workspace_md());
- const size_t ind_dt_size = indices
- ? types::data_type_size(indices_d.data_type()) : 0;
-
- const auto &jpp = pd()->jpp_;
-
- auto ker = [&](int n, int b_c, int oh) {
- auto arg = jit_pool_call_s();
-
- const int ij = oh * jpp.stride_h;
- const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
- const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
- const int ih = nstl::max(ij - jpp.t_pad, 0);
-
- arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)];
- arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)];
- if (indices) {
- const size_t ind_off = indices_d.blk_off(n, b_c, oh);
- arg.indices = &indices[ind_off * ind_dt_size];
- }
- arg.oh = (oh == 0);
- arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
- arg.kh_padding_shift = i_t_overflow*jpp.kw;
- arg.kw_padding = 0;
- arg.ker_area_h = (float)(jpp.kh -
- nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
- nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
-
- (*kernel_)(&arg);
- };
-
- parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
- for (int oh = 0; oh < jpp.oh; ++oh) {
- ker(n, b_c, oh);
- }
- });
-}
-
-template <cpu_isa_t isa>
-void jit_uni_pooling_bwd_t<isa>::execute_backward_3d(const data_t *diff_dst,
- const char *indices, data_t *diff_src) const {
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper indices_d(pd()->workspace_md());
- const size_t ind_dt_size = indices
- ? types::data_type_size(indices_d.data_type()) : 0;
-
- const auto &jpp = pd()->jpp_;
-
- auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
- int d_b_overflow, int zero_size, int kd) {
- auto arg = jit_pool_call_s();
-
- const int ij = oh * jpp.stride_h;
- const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
- const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
- const int ih = nstl::max(ij - jpp.t_pad, 0);
-
- arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)];
- arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)];
- if (indices) {
- const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
- arg.indices = &indices[ind_off * ind_dt_size];
- }
- arg.oh = zero_size;
- arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
- arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
- arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh
- + kd * jpp.kw * jpp.kh;
- arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
- arg.kw_padding = 0;
- arg.ker_area_h = (float)(jpp.kh -
- nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
- nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
- nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
- nstl::max(0, jpp.f_pad - od*jpp.stride_d));
-
- (*kernel_)(&arg);
- };
-
- if (jpp.simple_alg) {
-
- parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
- [&](int n, int b_c, int od) {
- const int ik = od * jpp.stride_d;
- const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
- const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
- - jpp.f_pad) - jpp.id;
- const int id = nstl::max(ik - jpp.f_pad, 0);
- int zero_s = jpp.stride_d - d_t_overflow - (nstl::max(
- jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id);
- for (int oh = 0; oh < jpp.oh; ++oh) {
- ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
- (oh == 0) ? zero_s : 0, 0);
- }
- });
- } else {
- ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c
- * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw;
-
- parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; });
-
- for (int kd = 0; kd < jpp.kd; ++kd) {
- parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
- for (int od = 0; od < jpp.od; ++od) {
- const int ik = od * jpp.stride_d;
- const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
- const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
- - jpp.f_pad) - jpp.id;
- if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
- continue;
- const int id = nstl::max(ik - jpp.f_pad, 0);
- for (int oh = 0; oh < jpp.oh; ++oh) {
- ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
- 0, kd);
- }
- }
- });
- }
- }
-}
-
-
-template struct jit_uni_pooling_fwd_t<sse42>;
-template struct jit_uni_pooling_bwd_t<sse42>;
-template struct jit_uni_pooling_fwd_t<avx>;
-template struct jit_uni_pooling_bwd_t<avx>;
-template struct jit_uni_pooling_fwd_t<avx512_common>;
-template struct jit_uni_pooling_bwd_t<avx512_common>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
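
In the 2D forward path removed above, the per-row arguments handed to the JIT kernel all derive from how far the pooling window hangs over the top and bottom of the input. A stand-alone sketch of that computation (names mirror the source; kh_padding and ker_area_h coincide in this 2D case, while the 3D path additionally multiplies in the depth overlap):

    #include <algorithm>

    // Per-output-row window bookkeeping, as in the ker lambda of
    // execute_forward(): how many kernel rows actually overlap the input.
    struct row_window_t { int ih_start; int kh_padding; float ker_area_h; };

    static row_window_t row_window(int oh, int stride_h, int t_pad, int kh, int ih) {
        const int ij = oh * stride_h;                     // window top in padded coords
        const int i_t_overflow = std::max(0, t_pad - ij); // rows above the input
        const int i_b_overflow =
                std::max(ih, ij + kh - t_pad) - ih;       // rows below the input
        row_window_t w;
        w.ih_start   = std::max(ij - t_pad, 0);           // first real input row
        w.kh_padding = kh - i_t_overflow - i_b_overflow;  // rows inside the input
        w.ker_area_h = float(w.kh_padding);               // avg divisor excluding padding
        return w;
    }
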
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp
deleted file mode 100644
index 57bebacdee..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp
+++ /dev/null
@@ -1,182 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_JIT_UNI_POOLING_HPP
-#define CPU_JIT_UNI_POOLING_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_pooling_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "jit_uni_pool_kernel_f32.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <cpu_isa_t isa>
-struct jit_uni_pooling_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_fwd_pd_t {
- using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_pooling_fwd_t<isa>);
-
- status_t init() {
- using namespace utils;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && !has_zero_dim_memory()
- && everyone_is(data_type::f32,
- src_md()->data_type,
- dst_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*src_md(), desired_fmt_tag())
- && memory_desc_matches_tag(*dst_md(), desired_fmt_tag());
- if (!ok) return status::unimplemented;
-
- bool is_training = desc_.prop_kind == prop_kind::forward_training;
- if (desc()->alg_kind == alg_kind::pooling_max && is_training)
- init_default_ws();
-
- return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
- }
-
- format_tag_t desired_fmt_tag() {
- using namespace format_tag;
- return ndims() == 4
- ? isa == avx512_common ? nChw16c : nChw8c
- : isa == avx512_common ? nCdhw16c : nCdhw8c;
- }
-
- jit_pool_conf_t jpp_;
- };
-
- jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
-
- ~jit_uni_pooling_fwd_t() { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE);
-
- if (pd()->ndims() == 5)
- execute_forward_3d(src, dst, ws);
- else
- execute_forward(src, dst, ws);
-
- return status::success;
- }
-
-private:
- void execute_forward(const data_t *src, data_t *dst, char *indices) const;
- void execute_forward_3d(const data_t *src, data_t *dst,
- char *indices) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_uni_pool_kernel_f32<isa> *kernel_;
-};
-
-template <cpu_isa_t isa>
-struct jit_uni_pooling_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_bwd_pd_t {
- using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
-
- DECLARE_COMMON_PD_T(
- JIT_IMPL_NAME_HELPER("jit:", isa, ""),
- jit_uni_pooling_bwd_t<isa>);
-
- status_t init() {
- using namespace utils;
-
- bool ok = true
- && set_default_params() == status::success
- && !is_fwd()
- && !has_zero_dim_memory()
- && everyone_is(data_type::f32,
- diff_src_md()->data_type,
- diff_dst_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag())
- && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag());
- if (!ok) return status::unimplemented;
-
- if (desc()->alg_kind == alg_kind::pooling_max) {
- init_default_ws();
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
- }
-
- format_tag_t desired_fmt_tag() {
- using namespace format_tag;
-            return ndims() == 4
- ? isa == avx512_common ? nChw16c : nChw8c
- : isa == avx512_common ? nCdhw16c : nCdhw8c;
- }
-
- jit_pool_conf_t jpp_;
- };
-
- jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
-
- ~jit_uni_pooling_bwd_t() { delete kernel_; }
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- if (pd()->ndims() == 5)
- execute_backward_3d(diff_dst, ws, diff_src);
- else
- execute_backward(diff_dst, ws, diff_src);
-
- return status::success;
- }
-
-private:
- void execute_backward(const data_t *diff_dst, const char *indices,
- data_t *diff_src) const;
- void execute_backward_3d(const data_t *diff_dst, const char *indices,
- data_t *diff_src) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- jit_uni_pool_kernel_f32<isa> *kernel_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp
deleted file mode 100644
index 98796503b7..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp
+++ /dev/null
@@ -1,1006 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_desc_wrapper.hpp"
-#include "mkldnn_debug.h"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu_primitive.hpp"
-#include "cpu_reorder_pd.hpp"
-#include "jit_uni_reorder.hpp"
-
-#include "jit_generator.hpp"
-
-// #define TR_DEBUG
-#if defined(TR_DEBUG)
-#define DEBUg(...) do { __VA_ARGS__ } while (0)
-#else
-#define DEBUg(...)
-#endif
-#define DEBUG(...) DEBUg(__VA_ARGS__)
-
-#ifdef _WIN32
-/* seems like s_addr is a reserved macro on Windows */
-#undef s_addr
-#endif
-
-using namespace Xbyak;
-using namespace mkldnn::impl::types;
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace tr {
-
-/** Minimal reasonable/desirable kernel size.
- * The constant might be used to determine how a problem should be split
- * between kernel and threading driver. */
-const size_t ker_prb_size_min = 64;
-
-/* kernel */
-struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator {
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
-
- enum {
- len_unroll_max = 256,
- ndims_jit_loop_max = 3,
- };
-
- struct simple_impl_desc_t {
- int ndims_full_unroll;
- int len_last_dim_unroll;
- int len_unroll;
- };
-
- static bool simple_impl_desc_init(const prb_t &prb,
- simple_impl_desc_t *desc) {
- const int ndims = prb.ndims;
-
- int ndims_full_unroll = 0;
- int len_last_dim_unroll = 1;
- int len_unroll = 1;
-
- for (int d = 0; d < ndims; ++d) {
- auto &node = prb.nodes[d];
- if (len_unroll * node.n <= len_unroll_max) {
- ndims_full_unroll++;
- len_unroll *= node.n;
- } else {
- len_last_dim_unroll = len_unroll_max / len_unroll;
- while (node.n % len_last_dim_unroll)
- --len_last_dim_unroll;
- len_unroll *= len_last_dim_unroll;
- break;
- }
- }
-
- if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max)
- return false;
-
- if (desc) {
- desc->ndims_full_unroll = ndims_full_unroll;
- desc->len_last_dim_unroll = len_last_dim_unroll;
- desc->len_unroll = len_unroll;
- }
-
- return true;
- }
-
- static bool applicable(const prb_t &p) {
- using namespace data_type;
-
- bool ok = true
- && p.ndims > 0
- && utils::one_of(p.itype, f32, s32, s8, u8)
- && utils::one_of(p.otype, f32, s32, s8, u8)
- && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
- && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
- && simple_impl_desc_init(p, nullptr)
- && mayiuse(sse42)
- && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype),
- mayiuse(avx));
- if (!ok) return false;
-
- const ptrdiff_t max_stride = (1LL<<31) - 1;
- for (int d = 0; d < p.ndims; ++d) {
- const ptrdiff_t cms = max_stride / p.nodes[d].n;
- bool strides_ok = true
- && p.nodes[d].is < cms / (int)data_type_size(p.itype)
- && p.nodes[d].os < cms / (int)data_type_size(p.otype);
- if (!strides_ok) return false;
- }
-
- return true;
- }
-
- int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; }
- int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; }
- int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; }
- int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; }
-
- Address i_addr(int i_off)
- { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; }
-
- Address o_addr(int o_off)
- { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; }
-
- Address s_addr(int s_off)
- { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; }
-
- void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
- int &i_off, int &o_off, int &s_off, int step_size = 1) {
- i_off = prev_i_off;
- o_off = prev_o_off;
- s_off = prev_s_off;
-
- if (off == 0) return;
-
- int start_dim = 0, dims_prod = 1;
- for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
- dims_prod *= n(start_dim);
- assert(start_dim < prb_.ndims);
- off /= step_size;
-
- for (int d = start_dim; d < prb_.ndims; ++d) {
- i_off += is(d);
- o_off += os(d);
- s_off += ss(d);
-
- if (off % n(d)) break;
-
- i_off += - n(d) * is(d);
- o_off += - n(d) * os(d);
- s_off += - n(d) * ss(d);
- off /= n(d);
-
- if (off == 0) break; /* FIXME: is it really required? */
- }
- }
-
- void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
- int step_size = 1) {
- int dummy = 0;
- step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
- step_size);
- }
-
- void tr8x8_avx2(int i_off, int o_off) {
- for (int i = 0; i < 8; i++)
- vmovups(Ymm(i), i_addr(i_off + i * 8));
-
- for (int i = 0; i < 8 / 2; i++) {
- vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1));
- vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
- }
-
- const unsigned int lfloat = 0x44;
- const unsigned int ufloat = 0xee;
- for (int i = 0; i < 8 / 2; i++) {
- int j = i % 2 == 0 ? 8 + i : i - 1;
- vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
- vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
- }
-
- const unsigned int lquad = 0x20;
- for (int i = 0; i < 8 / 2; i++)
- vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad);
-
- const unsigned int uquad = 0x31;
- for (int i = 8 / 2; i < 8; i++)
- vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad);
-
- for (int i = 0; i < 8; i++)
- vmovups(o_addr(o_off + i * 8), Ymm(i));
- }
-
- bool process_unroll_tr8x8(int len) {
- bool can_do = true
- && mayiuse(avx2)
- && prb_.ndims >= 2
- && utils::everyone_is(4, itype_sz, otype_sz)
- && utils::everyone_is(8, n(0), n(1))
- && utils::everyone_is(1, os(0), is(1))
- && utils::everyone_is(8, os(1), is(0))
- && prb_.scale_type == scale_type_t::NONE
- && prb_.beta == 0.f;
- if (!can_do) return false;
-
- const int step_size = n(0) * n(1);
- int i_off = 0, o_off = 0;
- for (int off = 0; off < len; off += step_size) {
- step(off, i_off, o_off, i_off, o_off, step_size);
- tr8x8_avx2(i_off, o_off);
- }
-
- return true;
- }
-
- template <cpu_isa_t isa>
- bool process_direct_copy(int len) {
- using namespace data_type;
-
- using Vmm = typename cpu_isa_traits<isa>::Vmm;
- const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;
-
- bool can_do = true
- && mayiuse(isa)
- && utils::everyone_is(1, os(0), is(0))
- && (false
- || prb_.itype == prb_.otype
- || (prb_.itype == s32 && prb_.otype == f32)
- || (prb_.itype == f32 && prb_.otype == s32)
- )
- && len % simd_w == 0
- && n(0) % len == 0
- && prb_.scale_type == scale_type_t::NONE
- && prb_.beta == 0.f;
- if (!can_do) return false;
-
- for (int off = 0; off < len;) {
- const int unroll = nstl::min(16, (len - off) / simd_w);
-
- for (int ur = 0; ur < unroll; ++ur)
- uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w));
-
- if (prb_.itype != prb_.otype) {
- for (int ur = 0; ur < unroll; ++ur) {
- if (prb_.itype == s32 && prb_.otype == f32)
- uni_vcvtdq2ps(Vmm(ur), Vmm(ur));
- else if (prb_.itype == f32 && prb_.otype == s32)
- uni_vcvtps2dq(Vmm(ur), Vmm(ur));
- else assert(!"unreachable");
- }
- }
-
- for (int ur = 0; ur < unroll; ++ur)
- uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur));
-
- off += unroll * simd_w;
- }
-
- return true;
- }
-
- void process_unroll_generic_step(int reg_unroll, const int *i_off,
- const int *o_off, const int *s_off) {
- using namespace data_type;
-
- auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) {
- Xmm dst_pure = Xmm(dst.getIdx());
- switch (idt) {
- case f32:
- if (src.isMEM() || src.getIdx() != dst.getIdx())
- vmovups(dst, src);
- break;
- case s32: vcvtdq2ps(dst, src); break;
- case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
- case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
- default: assert(!"unreachable");
- }
- };
-
- auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) {
- switch (odt) {
- case s32:
- if (idt == f32) vcvtps2dq(xmm, xmm);
- else if (idt == s8) vpmovsxbd(xmm, xmm);
- else if (idt == u8) vpmovzxbd(xmm, xmm);
- break;
- case s8:
- if (idt == f32) vcvtps2dq(xmm, xmm);
- if (idt == f32 || idt == s32) {
- if (mayiuse(avx512_core)) {
- vpmovsdb(xmm, xmm);
- } else {
- vpackssdw(xmm, xmm, xmm_zero);
- vpacksswb(xmm, xmm, xmm_zero);
- }
- }
- if (idt == u8) vpminub(xmm, xmm, xmm_4x127b);
- break;
- case u8:
- if (idt == f32) vcvtps2dq(xmm, xmm);
- if (idt == f32 || idt == s32) {
- if (mayiuse(avx512_core)) {
- vpmaxsd(xmm, xmm, xmm_zero);
- vpmovusdb(xmm, xmm);
- } else {
- vpackssdw(xmm, xmm, xmm_zero);
- vpackuswb(xmm, xmm, xmm_zero);
- }
- }
- if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero);
- break;
- default: assert(!"unreachable");
- }
- };
-
- auto load = [=](const Xmm &xmm, const Address &addr, int size) {
- switch (size) {
- case 16: movups(xmm, addr); break;
- case 4: movss(xmm, addr); break;
- case 1: pinsrb(xmm, addr, 0x0); break;
- default: assert(!"unreachable");
- }
- };
-
- auto store = [=](const Address &addr, const Xmm &xmm, int size) {
- switch (size) {
- case 16: movups(addr, xmm); break;
- case 4: movss(addr, xmm); break;
- case 1: pextrb(addr, xmm, 0x0); break;
- default: assert(!"unreachable");
- }
- };
-
- /* check whether loading 4 values at once is possible */
- bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0;
- for (int ur = 1; ur < reg_unroll; ++ur)
- if (i_off[ur] != i_off[ur - 1] + 1)
- can_load_xmm = false;
- const int load_step = can_load_xmm ? 4 : 1;
-
- /* check whether storing 4 values at once is possible */
- bool can_store_xmm = reg_unroll % 4 == 0;
- for (int ur = 1; ur < reg_unroll; ++ur)
- if (o_off[ur] != o_off[ur - 1] + 1)
- can_store_xmm = false;
- const int ur_step = can_store_xmm ? 4 : 1;
-
- const bool interim_f32 = false
- || utils::one_of(f32, prb_.itype, prb_.otype)
- || prb_.scale_type != scale_type_t::NONE
- || prb_.beta != 0.f;
-
- if (!can_load_xmm && can_store_xmm) {
- assert(ur_step == 4);
- /* load with stride */
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- for (int r = 0; r < ur_step; ++r) {
- if (itype_sz == 4)
- pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r);
- else
- pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r);
- }
- }
- } else {
- for (int ur = 0; ur < reg_unroll; ur += load_step)
- load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz);
- }
-
- /* xmm[:] <-- (f32)xmm[:] */
- if (interim_f32) {
- const int cvt_step = nstl::max(load_step, ur_step);
- for (int ur = 0; ur < reg_unroll; ur += cvt_step)
- cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
- }
-
- if (can_load_xmm && !can_store_xmm) {
- const bool fast_return = true // transposition on the fly
- && prb_.scale_type != scale_type_t::MANY
- && prb_.beta == 0.f;
- if (fast_return) {
- for (int ur = 0; ur < reg_unroll; ur += load_step) {
- if (prb_.scale_type == scale_type_t::COMMON)
- mulps(Xmm(ur), xmm_scale);
- if (prb_.otype != f32)
- cvt2int(Xmm(ur), prb_.otype,
- interim_f32 ? f32 : prb_.itype);
- for (int r = 0; r < load_step; ++r) {
- if (otype_sz == 4)
- pextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
- else
- pextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
- }
- }
- return;
- }
-
- /* scatter elements of xmm into 4 xmms */
- if (itype_sz == 4 || interim_f32) {
- for (int ur = 0; ur < reg_unroll; ur += load_step)
- for (int r = 1; r < load_step; ++r)
- vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
- } else {
- for (int ur = 0; ur < reg_unroll; ur += load_step)
- for (int r = 1; r < load_step; ++r)
- vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
- }
- }
-
- /* scale and beta processing */
- if (can_store_xmm) {
- /* xmm <-- scale * xmm[:] */
- if (prb_.scale_type == scale_type_t::COMMON) {
- for (int ur = 0; ur < reg_unroll; ur += ur_step)
- mulps(Xmm(ur), xmm_scale);
- } else if (prb_.scale_type == scale_type_t::MANY) {
- enum class scale_load_type_t { bcast, load, gather };
-
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- scale_load_type_t scale_load_type =
- scale_load_type_t::bcast; // the best case
-
- for (int r = ur + 1; r < ur + ur_step; ++r)
- if (s_off[r] != s_off[r - 1] + 0)
- scale_load_type = scale_load_type_t::load;
-
- if (scale_load_type == scale_load_type_t::bcast) {
- movss(xmm_scale, s_addr(s_off[ur]));
- shufps(xmm_scale, xmm_scale, 0x0);
- mulps(Xmm(ur), xmm_scale);
- continue;
- }
-
- // bcast doesn't work, the next try -- load
- for (int r = ur + 1; r < ur + ur_step; ++r)
- if (s_off[r] != s_off[r - 1] + 1)
- scale_load_type = scale_load_type_t::gather;
-
- if (scale_load_type == scale_load_type_t::load) {
- movups(xmm_scale, s_addr(s_off[ur]));
- mulps(Xmm(ur), xmm_scale);
- continue;
- }
-
- // load doesn't work as well
- // so gather the scale factors one by one
- for (int r = ur; r < ur + ur_step; ++r)
- pinsrd(xmm_scale, s_addr(s_off[r]), r - ur);
- mulps(Xmm(ur), xmm_scale);
- }
- }
-
- /* dst <-- beta * dst + xmm[:] */
- assert(prb_.beta == 0.f || prb_.beta == 1.f);
- if (prb_.beta == 1.f) {
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- if (prb_.otype == f32) {
- /* non VEX instructions do not support unaligned
- * memory for instructions other than movups. */
- if (mayiuse(avx)) {
- vaddps(Xmm(ur), o_addr(o_off[ur]));
- } else {
- /* register xmm(1) is unused */
- movups(Xmm(1), o_addr(o_off[ur]));
- addps(Xmm(ur), Xmm(1));
- }
- } else {
- cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
- vaddps(Xmm(ur), Xmm(1));
- }
- }
- }
- } else {
- /* xmm[0] <-- scale * xmm[0] */
- if (prb_.scale_type == scale_type_t::COMMON) {
- for (int ur = 0; ur < reg_unroll; ur += ur_step)
- mulss(Xmm(ur), xmm_scale);
- } else if (prb_.scale_type == scale_type_t::MANY) {
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- mulss(Xmm(ur), s_addr(s_off[ur]));
- }
- }
-
- /* dst <-- beta * dst + xmm[0] */
- assert(prb_.beta == 0.f || prb_.beta == 1.f);
- if (prb_.beta == 1.f) {
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- if (prb_.otype == f32) {
- addss(Xmm(ur), o_addr(o_off[ur]));
- } else {
- if (prb_.otype == s32) {
- vmovss(xmm_tmp, o_addr(o_off[ur]));
- } else if (utils::one_of(prb_.otype, s8, u8)) {
- pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0);
- } else {
- assert(!"unsupported o_type");
- }
- cvt2ps(xmm_tmp, xmm_tmp, prb_.otype);
- addps(Xmm(ur), xmm_tmp);
- }
- }
- }
- }
-
- for (int ur = 0; ur < reg_unroll; ur += ur_step) {
- if (prb_.otype != f32)
- cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
- store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz);
- }
- }
-
- void process_unroll_generic(int len) {
- const int blk = 8;
-
- int i_off[2 * blk] = {0};
- int o_off[2 * blk] = {0};
- int s_off[2 * blk] = {0};
-
- int curr = 0; // will switch between 0 and 1
-
- for (int off = 0; off < len; off += blk) {
- const int reg_unroll = nstl::min(off + blk, len) - off;
-
- /* compute offsets */
- for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
- const int ur_c = curr * blk + ur;
- const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
- step(off + ur,
- i_off[ur_p], o_off[ur_p], s_off[ur_p],
- i_off[ur_c], o_off[ur_c], s_off[ur_c]);
- }
-
- process_unroll_generic_step(reg_unroll, i_off + curr * blk,
- o_off + curr * blk, s_off + curr * blk);
-
- curr = 1 - curr;
- }
- }
-
- void loop_begin(Label &l, Reg64 reg_cnt, int len) {
- mov(reg_cnt, len);
- L(l);
- }
-
- void loop_end(Label &l, Reg64 reg_cnt, int len,
- int i_step, int o_step, int s_step) {
- add(reg_off_in, i_step * itype_sz);
- add(reg_off_out, o_step * otype_sz);
- if (prb_.scale_type == scale_type_t::MANY)
- add(reg_off_scale, s_step * stype_sz);
- dec(reg_cnt);
- jnz(l);
-
- sub(reg_off_in, len * i_step * itype_sz);
- sub(reg_off_out, len * o_step * otype_sz);
- if (prb_.scale_type == scale_type_t::MANY)
- sub(reg_off_scale, len * s_step * stype_sz);
- }
-
- bool simple_impl() {
- simple_impl_desc_t d;
- if (!simple_impl_desc_init(prb_, &d)) return false;
-
- const int nfu = d.ndims_full_unroll;
- const int ldu = d.len_last_dim_unroll;
- const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
- assert(n_jit_loops <= ndims_jit_loop_max);
-
- xor_(reg_off_in, reg_off_in);
- xor_(reg_off_out, reg_off_out);
- if (prb_.scale_type == scale_type_t::MANY)
- xor_(reg_off_scale, reg_off_scale);
-
- Label l_loop[3];
- Reg64 reg_cnt[3] = {r15, r14, r13};
-
- if (n_jit_loops > 2)
- loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));
-
- if (n_jit_loops > 1)
- loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));
-
- if (n_jit_loops > 0)
- loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);
-
- const bool optimized = false
- || process_direct_copy<avx>(d.len_unroll)
- || process_direct_copy<sse42>(d.len_unroll)
- || process_unroll_tr8x8(d.len_unroll);
- if (!optimized)
- process_unroll_generic(d.len_unroll);
-
- if (n_jit_loops > 0)
- loop_end(l_loop[0], reg_cnt[0],
- n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu,
- ss(nfu + 0) * ldu);
-
- if (n_jit_loops > 1)
- loop_end(l_loop[1], reg_cnt[1],
- n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1));
-
- if (n_jit_loops > 2)
- loop_end(l_loop[2], reg_cnt[2],
- n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2));
-
- return true;
- }
-
- void impl() {
- if (simple_impl()) return;
- assert(!"no implementation available");
- }
-
- jit_uni_reorder_kernel_f32(const desc_t &desc)
- : kernel_t(desc), jit_generator() {
- itype_sz = data_type_size(prb_.itype);
- otype_sz = data_type_size(prb_.otype);
- stype_sz = sizeof(float);
-
- preamble();
-# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)]
- if (prb_.scale_type == scale_type_t::COMMON) {
- auto reg_ptr_scale_tmp = reg_ptr_in;
- mov(reg_ptr_scale_tmp, PARAM(scale));
- movups(xmm_scale, ptr[reg_ptr_scale_tmp]);
- } else if (prb_.scale_type == scale_type_t::MANY) {
- mov(reg_ptr_scale, PARAM(scale));
- }
- mov(reg_ptr_in, PARAM(in));
- mov(reg_ptr_out, PARAM(out));
-# undef PARAM
-
- if (mayiuse(avx)) {
- vxorps(xmm_zero, xmm_zero, xmm_zero);
-
- if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
- mov(reg_tmp.cvt32(), 0x7f7f7f7f);
- movd(xmm_4x127b, reg_tmp.cvt32());
- }
- }
-
- impl();
- postamble();
- ker_ = (void (*)(const call_param_t *))getCode();
- }
-
-private:
- int itype_sz;
- int otype_sz;
- int stype_sz;
-
- Reg64 reg_ptr_in = rsi;
- Reg64 reg_ptr_out = rdx;
- Reg64 reg_ptr_scale = abi_not_param1;
-
- Reg64 reg_off_in = r8;
- Reg64 reg_off_out = r9;
- Reg64 reg_off_scale = r10;
-
- Reg64 reg_tmp = rax;
-
- Xmm xmm_scale = xmm15;
- Xmm xmm_zero = xmm14;
- Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero
- Xmm xmm_tmp = xmm12;
-};
-
-status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb,
- int ndims_ker_max) {
- desc.prb = prb;
- desc.prb.ioff = desc.prb.ooff = 0;
-
- if (ndims_ker_max > prb.ndims)
- return status::invalid_arguments;
-
- auto ndims_ker_max_f = [&]() {
- size_t cur_size = 1;
- for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
- if (cur_size >= ker_prb_size_min) return d;
- return prb.ndims;
- };
-
- if (ndims_ker_max <= 0)
- ndims_ker_max = ndims_ker_max_f();
-
- /* traverse through kernel implementations */
- /* TODO: find a better way to do that... */
- desc.id = 0;
- for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
- desc.prb.ndims = ndims_ker;
- if (jit_uni_reorder_kernel_f32::applicable(desc.prb))
- return status::success;
- }
-
- return status::unimplemented;
-}
-
-kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
- switch (desc.id) {
- case 0: return new jit_uni_reorder_kernel_f32(desc);
- default: assert(!"unknown kernel id"); return nullptr;
- }
-
- return nullptr;
-}
-
-}
-
-static void prb_block_for_cache(tr::prb_t &prb) {
- if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) {
-        /** an attempt to use caches more efficiently and to
-         * address the 4K-aliasing issue */
- /* TODO: improve the logic around here */
- int j = 1;
- for (; j < prb.ndims && prb.nodes[j].is != 1; ++j);
- if (j == prb.ndims) return;
-
-        /* it makes sense to re-prioritize sequential reads over
-         * sequential writes if the former would not thrash the
-         * cache, i.e. is == 1 and os % 2^smth != 0, where smth is
-         * currently set to 2 */
- const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1;
- if (j == move_to) return;
-
- if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0)
- prb_node_split(prb, j, 16);
-
- prb_node_move(prb, j, move_to);
- DEBUG({ printf("cache: "); prb_dump(prb); });
- }
-}
-
-/** finds the maximum number of dimensions the kernel should process and
- * optionally splits one of the dimensions to achieve better balance between
- * the parallel driver and the kernel. */
-static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) {
- size_t sz_total = 1;
- for (int d = 0; d < prb.ndims; ++d)
- sz_total *= prb.nodes[d].n;
-
- /* sz_drv_min is the minimal size for the parallel
- * driver required for good parallelization */
- const size_t sz_drv_min = nstl::min<size_t>(
- 16 * mkldnn_get_max_threads(),
- utils::div_up(sz_total, 1024));
-
- /* kdims -- # of dimensions processed by a kernel
- * sz_ker_cur -- product of the dimension processed by a kernel
- * sz_drv_cur -- product of the dimension processed by a driver */
-
- int kdims = prb.ndims;
- size_t sz_drv_cur = 1;
- for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
- sz_drv_cur *= prb.nodes[kdims - 1].n;
-
- size_t sz_ker_cur = 1;
- for (int d = 0; d < kdims; ++d)
- sz_ker_cur *= prb.nodes[d].n;
-
- /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
- *
- * It might happen that for chosen kdims the sz_ker_cur is too small
- * (less than tr::ker_prb_size_min). In that case try to split the
- * innermost driver dimension into two, to increase sz_ker_cur. */
- bool want_borrow_ker_from_drv = true
- && kdims < prb.ndims
- && sz_ker_cur < tr::ker_prb_size_min
- && sz_drv_cur > sz_drv_min;
- if (want_borrow_ker_from_drv) {
- /* sz_want_borrow is the minimal sz, so that:
- * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
- * o) current innermost driver dimension is divisible by
- * sz_want_borrow (so that we can evenly split that
- * dimension into two)
- *
- * In the worst case the minimal sz_want_borrow is equal
- * to the innermost driver dimension itself. In that case
- * we will sacrifice it in favor of kernel (is it fine?). */
- size_t sz_want_borrow
- = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
- for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow);
- if (sz_want_borrow != prb.nodes[kdims].n)
- prb_node_split(prb, kdims, sz_want_borrow);
- kdims += 1;
- }
-
- /* On the other hand it might happen that for chosen kdims
- * the sz_drv_cur is too small (less than sz_drv_min). In that case
- * try to split the outermost kernel dimension into two, to increase
- * sz_drv_cur. */
- bool want_borrow_drv_from_ker = true
- && sz_ker_cur > tr::ker_prb_size_min
- && sz_drv_cur < sz_drv_min;
- if (want_borrow_drv_from_ker) {
- size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
- for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow);
- if (sz_want_borrow != prb.nodes[kdims - 1].n)
- prb_node_split(prb, kdims - 1,
- prb.nodes[kdims - 1].n / sz_want_borrow);
- }
-
- ndims_ker_max = kdims;
-
- if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
- DEBUG({ printf("split: "); prb_dump(prb);
- printf("ndims_ker_max = %d\n", ndims_ker_max); });
- }
-}
-
-struct jit_uni_reorder_t : public cpu_primitive_t {
- struct pd_t : public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
- auto prb = tr::prb_t();
-
- status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
- if (prb_init_status != status::success) return prb_init_status;
-
- DEBUG({ printf("init : "); prb_dump(prb); });
- prb_normalize(prb);
- DEBUG({ printf("norm : "); prb_dump(prb); });
- prb_simplify(prb);
- DEBUG({ printf("smpl : "); prb_dump(prb); });
-
- prb_block_for_cache(prb);
-
- int ndims_ker_max;
- prb_thread_kernel_balance(prb, ndims_ker_max);
-
- tr::kernel_t::desc_t ker_desc;
- status_t ker_init_status
- = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
- if (ker_init_status != status::success) return ker_init_status;
-
- const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
- if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
- return status::unimplemented;
-
- DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); });
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return status::out_of_memory;
- if (_pd->init() != status::success) {
- delete _pd;
- return status::unimplemented;
- }
- _pd->prb_ = prb;
- _pd->ker_desc_ = ker_desc;
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
-
- tr::prb_t prb_;
- tr::kernel_t::desc_t ker_desc_;
- };
-
- jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {
- kernel_ = tr::kernel_t::create(pd()->ker_desc_);
- assert(kernel_);
- }
- ~jit_uni_reorder_t() { delete kernel_; }
-
- void omp_driver_0d(int off, const char *in, char *out,
- const float *scale) const {
- tr::call_param_t c{in, out, scale};
- (*kernel_)(&c);
- }
-
- void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
- const float *scale) const {
- const tr::node_t *ns = pd()->prb_.nodes + off;
- for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
- auto c = tr::call_param_t();
- c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
- c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
- c.scale = scale + d0 * ns[0].ss;
- (*kernel_)(&c);
- });
- }
-
- void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
- const float *scale) const {
- const tr::node_t *ns = pd()->prb_.nodes + off;
- for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
- [&](ptrdiff_t d1, ptrdiff_t d0) {
- auto c = tr::call_param_t();
- c.in = in + (d0 * ns[0].is + d1 * ns[1].is)
- * data_type_size(pd()->prb_.itype);
- c.out = out + (d0 * ns[0].os + d1 * ns[1].os)
- * data_type_size(pd()->prb_.otype);
- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
- (*kernel_)(&c);
- });
- }
-
- void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
- const float *scale) const {
- const tr::node_t *ns = pd()->prb_.nodes + off;
- for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
- (ptrdiff_t)ns[0].n,
- [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
- auto c = tr::call_param_t();
- c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
- * data_type_size(pd()->prb_.itype);
- c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
- * data_type_size(pd()->prb_.otype);
- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
- (*kernel_)(&c);
- });
- }
-
- void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
- const float *scale) const {
- const tr::node_t *ns = pd()->prb_.nodes + off;
- for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
- (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
- [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
- auto c = tr::call_param_t();
- c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
- + d3 * ns[3].is) * data_type_size(pd()->prb_.itype);
- c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
- + d3 * ns[3].os) * data_type_size(pd()->prb_.otype);
- c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
- + d3 * ns[3].ss;
- (*kernel_)(&c);
- });
- }
-
- void omp_driver(const char *in, char *out, const float *scale) const {
- in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
- out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
-
- DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); });
- DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); });
-
- int ndims = pd()->prb_.ndims;
- int ndims_ker = pd()->ker_desc_.prb.ndims;
- assert(ndims - ndims_ker <= ndims_driver_max);
-
- if (ndims - ndims_ker == 0) {
- omp_driver_0d(ndims_ker, in, out, scale);
- } else {
- parallel(0, [&](const int ithr, const int nthr) {
- switch (ndims - ndims_ker) {
- case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break;
- case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break;
- case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break;
- case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break;
- default: assert(!"unimplemented");
- }
- });
- }
- }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM);
- auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
-
- omp_driver(in, out, pd()->attr()->output_scales_.scales_);
-
- return status::success;
- }
-
- enum { ndims_driver_max = 4 };
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- tr::kernel_t *kernel_;
-};
-
-status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
- return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr,
- src_engine, src_md, dst_engine, dst_md);
-}
-
-}
-}
-}
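
The heuristic in prb_thread_kernel_balance() above is easier to follow in a reduced form: starting from the innermost dimension, keep dimensions in the JIT kernel until the product of the remaining outer dimensions gives the threading driver enough independent iterations. A simplified model (it omits the node-splitting refinement that borrows work between driver and kernel):

    #include <cstddef>
    #include <vector>

    // Returns kdims: dimensions [0, kdims) are processed by the JIT kernel,
    // dimensions [kdims, ndims) by the parallel driver. Simplified model of
    // prb_thread_kernel_balance(); n[d] is the size of dimension d.
    static int choose_kdims(const std::vector<std::size_t> &n, std::size_t sz_drv_min) {
        int kdims = static_cast<int>(n.size());
        std::size_t sz_drv_cur = 1;
        while (kdims > 1 && sz_drv_cur < sz_drv_min)
            sz_drv_cur *= n[--kdims];  // peel outer dimensions off for the driver
        return kdims;
    }
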
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp
deleted file mode 100644
index 0746ea61d3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp
+++ /dev/null
@@ -1,127 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef _JIT_UNI_REORDER_HPP
-#define _JIT_UNI_REORDER_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu_primitive.hpp"
-#include "cpu_reorder_pd.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace tr {
-
-constexpr int max_ndims = MKLDNN_MAX_NDIMS;
-
-struct node_t {
- size_t n;
- ptrdiff_t is; // input stride
- ptrdiff_t os; // output stride
- ptrdiff_t ss; // scale stride
-};
-
-enum class scale_type_t { NONE, COMMON, MANY };
-
-struct prb_t {
- data_type_t itype;
- data_type_t otype;
- int ndims;
- node_t nodes[max_ndims];
- ptrdiff_t ioff;
- ptrdiff_t ooff;
- scale_type_t scale_type;
- float beta;
-};
-
-status_t prb_init(prb_t &prb, const memory_desc_t &imd,
- const memory_desc_t &omd, const primitive_attr_t *attr);
-
-/** sorts the problem nodes so that output strides come in ascending order */
-void prb_normalize(prb_t &p);
-
-/** folds nodes together if possible */
-void prb_simplify(prb_t &p);
-
-/** splits the node dim into two of sizes n1 and n / n1
- * @warning n must be multiple of n1 */
-void prb_node_split(prb_t &p, int dim, size_t n1);
-
-/** swaps d0 and d1 nodes */
-void prb_node_swap(prb_t &p, int d0, int d1);
-
-/** moves node d0 to the d1 position.
- * nodes (d0, d1] are shifted to the left if d0 < d1 or
- * to the right if d0 > d1 */
-void prb_node_move(prb_t &p, int d0, int d1);
-
-/** dumps the problem to stdout */
-void prb_dump(const prb_t &p);
-
-struct call_param_t {
- const void *in;
- void *out;
- const float *scale;
-};
-
-struct kernel_t {
- struct desc_t {
- int id;
- prb_t prb;
- };
-
- kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {}
- void operator()(const call_param_t *c) const { assert(ker_); ker_(c); }
- virtual ~kernel_t() {}
-
- /** inits kernel descriptor:
- * desc -- kernel descriptor (output)
- * prb -- transposition problem (input)
- * ndims_ker_max -- limit the maximum number of dimensions kernel
- * will process (optional, 0 -- no limitation) */
- static status_t desc_init(desc_t &desc, const prb_t &prb,
- int ndims_ker_max = 0);
-
- /** creates kernel for the problem described in desc */
- static kernel_t *create(const desc_t &desc);
-
-protected:
- const desc_t desc_;
- const prb_t &prb_ = desc_.prb;
- void (*ker_)(const call_param_t *);
-};
-
-/* TODO: add trans_t class */
-
-}
-
-/* for cpu reorder list */
-status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md);
-
-}
-}
-}
-
-#endif
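
The prb_t/node_t description above fully determines what a reorder computes; only the execution strategy is left to the JIT kernels. A scalar reference of those semantics (illustration only: f32-to-f32, no data-type conversion; a scale stride ss of 0 models a common scale):

    #include <cstddef>

    struct node { std::size_t n; std::ptrdiff_t is, os, ss; };  // as tr::node_t

    // Scalar reference for the transposition problem: every logical element is
    // copied from in to out using the per-node input/output/scale strides.
    static void reorder_ref(const node *nodes, int ndims, const float *in,
                            float *out, const float *scale, float beta) {
        std::size_t total = 1;
        for (int d = 0; d < ndims; ++d) total *= nodes[d].n;

        for (std::size_t e = 0; e < total; ++e) {
            std::size_t rem = e;
            std::ptrdiff_t i_off = 0, o_off = 0, s_off = 0;
            for (int d = 0; d < ndims; ++d) {   // mixed-radix decode of e
                const std::size_t idx = rem % nodes[d].n;
                rem /= nodes[d].n;
                i_off += idx * nodes[d].is;
                o_off += idx * nodes[d].os;
                s_off += idx * nodes[d].ss;
            }
            const float s = scale ? scale[s_off] : 1.f;
            out[o_off] = s * in[i_off] + beta * out[o_off];
        }
    }
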
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
deleted file mode 100644
index 69b7a33604..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
+++ /dev/null
@@ -1,313 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_desc_wrapper.hpp"
-#include "mkldnn_debug.h"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "jit_uni_reorder.hpp"
-
-using namespace mkldnn::impl::types;
-using namespace mkldnn::impl::status;
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace tr {
-
-/** ad-hoc structure to describe blocked memory layout */
-struct layout_desc_t {
- data_type_t dt;
- int ndims;
- dims_t id;
- dims_t dims;
- strides_t strides;
-};
-
-status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
- layout_desc_t &ld) {
- const auto md = memory_desc_wrapper(md_);
-
- bool ok = true
- && md.is_blocking_desc()
- && md.extra().flags == 0;
- if (!ok) return invalid_arguments;
-
- const auto &bd = md.blocking_desc();
-
- ld.ndims = 0;
- ld.dt = md.data_type();
-
- auto P = [&ld](int id, int dim, ptrdiff_t stride) {
- assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
- ld.id[ld.ndims] = id;
- ld.dims[ld.ndims] = dim;
- ld.strides[ld.ndims] = stride;
- ++ld.ndims;
- };
-
- dims_t blocks;
- md.compute_blocks(blocks);
-
- for (int d = 0; d < md.ndims(); ++d) {
- const int ld_ndims_start = ld.ndims;
- if (blocks[d] != 1) {
- stride_t stride = 1;
- for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
- if (bd.inner_idxs[iblk] == d)
- P(d, bd.inner_blks[iblk], stride);
- stride *= bd.inner_blks[iblk];
- }
- }
- P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);
-
- // TODO: NOW: revisit, do we need a reverse?
- // TODO: NOW: consider using strides instead of block sizes in md
- // reverse the order of dims
- for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
- const int idx0 = ld_ndims_start + ld_d;
- const int idx1 = ld.ndims - 1 - ld_d;
- nstl::swap(ld.dims[idx0], ld.dims[idx1]);
- nstl::swap(ld.strides[idx0], ld.strides[idx1]);
- }
- }
-
- return success;
-}
-
-status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
- const primitive_attr_t *attr) {
- auto im_d = memory_desc_wrapper(imd);
- auto om_d = memory_desc_wrapper(omd);
-
- bool ok = true
- && im_d.is_blocking_desc()
- && om_d.is_blocking_desc()
- && !im_d.has_zero_dim()
- && !om_d.has_zero_dim();
- if (!ok)
- return unimplemented;
-
- dims_t iblocks, oblocks;
- im_d.compute_blocks(iblocks);
- om_d.compute_blocks(oblocks);
-
- /* padding_dim consistency check */
- for (int d = 0; d < im_d.ndims(); ++d) {
- const auto pdim = im_d.padded_dims()[d];
- bool ok = true
- && pdim == om_d.padded_dims()[d]
- && pdim % iblocks[d] == 0
- && pdim % oblocks[d] == 0;
- if (!ok) return unimplemented;
- }
-
- layout_desc_t ild, old;
- status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
- if (status != success) return status;
- status = cvt_mem_desc_to_layout_desc(omd, old);
- if (status != success) return status;
-
- p.itype = ild.dt;
- p.otype = old.dt;
-
- p.scale_type = attr->output_scales_.has_default_values()
- ? scale_type_t::NONE
- : (attr->output_scales_.mask_ == 0
- ? scale_type_t::COMMON
- : scale_type_t::MANY);
-
- ptrdiff_t ss[max_ndims] = {0};
- if (p.scale_type == scale_type_t::MANY) {
- ptrdiff_t last_ss = 1;
- for (int d = old.ndims - 1; d >=0; --d) {
- assert((d == 0 || old.id[d - 1] <= old.id[d])
- && "logical dimensions should be in ascending order");
- if (attr->output_scales_.mask_ & (1 << old.id[d])) {
- ss[d] = last_ss;
- last_ss *= old.dims[d];
- }
- }
- }
-
- int ndims = 0;
-
- int i_pos = 0; /* state for input -- current dimension */
- int o_pos = 0; /* state for output -- current dimension */
-
- while (i_pos < ild.ndims && o_pos < old.ndims) {
- assert(ild.id[i_pos] == old.id[o_pos]);
- if (ild.id[i_pos] != old.id[o_pos])
- return runtime_error;
-
- assert(ndims < max_ndims);
- if (ndims == max_ndims)
- return runtime_error;
-
- if (ild.dims[i_pos] == old.dims[o_pos]) {
- p.nodes[ndims].n = ild.dims[i_pos];
- p.nodes[ndims].is = ild.strides[i_pos];
- p.nodes[ndims].os = old.strides[o_pos];
- p.nodes[ndims].ss = ss[o_pos];
- ++ndims;
- ++i_pos;
- ++o_pos;
- } else if (ild.dims[i_pos] < old.dims[o_pos]) {
- assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
- int factor = old.dims[o_pos] / ild.dims[i_pos];
- p.nodes[ndims].n = ild.dims[i_pos];
- p.nodes[ndims].is = ild.strides[i_pos];
- p.nodes[ndims].os = old.strides[o_pos] * factor;
- p.nodes[ndims].ss = ss[o_pos] * factor;
- ++ndims;
- ++i_pos;
- old.dims[o_pos] = factor;
- } else if (ild.dims[i_pos] > old.dims[o_pos]) {
- assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
- int factor = ild.dims[i_pos] / old.dims[o_pos];
- p.nodes[ndims].n = old.dims[o_pos];
- p.nodes[ndims].is = ild.strides[i_pos] * factor;
- p.nodes[ndims].os = old.strides[o_pos];
- p.nodes[ndims].ss = ss[o_pos];
- ++ndims;
- ++o_pos;
- ild.dims[i_pos] = factor;
- }
- }
- p.ndims = ndims;
-
- dims_t zero_pos = {0};
- p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
- p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
-
- const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
- p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
-
- return success;
-}
-
-void prb_normalize(prb_t &p) {
- for (int d = 0; d < p.ndims; ++d) {
- int min_pos = d;
- for (int j = d + 1; j < p.ndims; ++j) {
- bool new_min = false
- || p.nodes[j].os < p.nodes[min_pos].os
- || (true
- && p.nodes[j].os == p.nodes[min_pos].os
- && p.nodes[j].n < p.nodes[min_pos].n);
- if (new_min) min_pos = j;
- }
- if (min_pos != d)
- nstl::swap(p.nodes[d], p.nodes[min_pos]);
- }
-}
-
-void prb_simplify(prb_t &p) {
-#if defined(__GNUC__) && __GNUC__ >= 4
-/* GCC produces bogus array subscript is above array bounds warning for
- * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Warray-bounds"
-#endif
- for (int d = 0; d < p.ndims - 1; ++d) {
- auto &this_node = p.nodes[d + 0];
- auto &next_node = p.nodes[d + 1];
- const bool fold = false
- || next_node.n == (size_t)1 // trivial case, just drop next node
- || (true // or real folding if possible
- && next_node.is == (ptrdiff_t)this_node.n * this_node.is
- && next_node.os == (ptrdiff_t)this_node.n * this_node.os
- && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
- if (fold) {
- this_node.n *= next_node.n;
- for (int j = d + 2; j < p.ndims; ++j)
- p.nodes[j - 1] = p.nodes[j];
- --p.ndims;
- --d; // make another try
- }
- }
-#if defined(__GNUC__) && __GNUC__ >= 4
-#pragma GCC diagnostic pop
-#endif
-}
-
-void prb_node_split(prb_t &p, int dim, size_t n1) {
- assert(dim < p.ndims);
- assert(p.ndims < max_ndims);
- assert(p.nodes[dim].n % n1 == 0);
-
- p.ndims += 1;
-
- for (int d = p.ndims; d > dim + 1; --d)
- p.nodes[d] = p.nodes[d - 1];
-
- p.nodes[dim + 1].n = p.nodes[dim].n / n1;
- p.nodes[dim + 1].is = p.nodes[dim].is * n1;
- p.nodes[dim + 1].os = p.nodes[dim].os * n1;
- p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
-
- p.nodes[dim].n = n1;
-}
-
-void prb_node_swap(prb_t &p, int d0, int d1) {
- assert(d0 < p.ndims);
- assert(d1 < p.ndims);
- assert(p.ndims < max_ndims);
-
- if (d0 == d1) return;
-
- nstl::swap(p.nodes[d0], p.nodes[d1]);
-}
-
-void prb_node_move(prb_t &p, int d0, int d1) {
- assert(d0 < p.ndims);
- assert(d1 < p.ndims);
- assert(p.ndims < max_ndims);
-
- if (d0 == d1) return;
-
- node_t node = p.nodes[d0];
-
- if (d0 < d1)
- for (int d = d0; d < d1; ++d)
- p.nodes[d] = p.nodes[d + 1];
- else
- for (int d = d0; d > d1; --d)
- p.nodes[d] = p.nodes[d - 1];
-
- p.nodes[d1] = node;
-}
-
-void prb_dump(const prb_t &p) {
- printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
- mkldnn_dt2str(p.otype), p.ndims);
- for (int d = 0; d < p.ndims; ++d)
- printf("[%zu:%td:%td:%td]",
- p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
- printf(" off:%zu:%zu\n", p.ioff, p.ooff);
-}
-
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp
deleted file mode 100644
index 08747aa89c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp
+++ /dev/null
@@ -1,115 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <mutex>
-
-#include "utils.hpp"
-
-#ifndef MKLDNN_ENABLE_JIT_PROFILING
-#define MKLDNN_ENABLE_JIT_PROFILING 1
-#endif
-
-#ifndef MKLDNN_ENABLE_JIT_DUMP
-#define MKLDNN_ENABLE_JIT_DUMP 1
-#endif
-
-#if MKLDNN_ENABLE_JIT_PROFILING
-#include "jitprofiling/jitprofiling.h"
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-namespace jit_utils {
-
-// WARNING: These functions are not thread safe and must be protected by a
-// mutex
-
-void dump_jit_code(const void *code, size_t code_size, const char *code_name)
-{
-#if MKLDNN_ENABLE_JIT_DUMP
- if (code && jit_dump_enabled()) {
- static int counter = 0;
-#define MAX_FNAME_LEN 256
- char fname[MAX_FNAME_LEN + 1];
- // TODO (Roma): support prefix for code / linux perf dumps
- snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name,
- counter);
- counter++;
-
- FILE *fp = fopen(fname, "w+");
- // Failure to dump code is not fatal
- if (fp) {
- size_t unused = fwrite(code, code_size, 1, fp);
- UNUSED(unused);
- fclose(fp);
- }
- }
-#undef MAX_FNAME_LEN
-#else
- UNUSED(code);
- UNUSED(code_size);
- UNUSED(code_name);
-#endif
-}
-
-void register_jit_code_vtune(const void *code, size_t code_size,
- const char *code_name, const char *source_file_name)
-{
-#if MKLDNN_ENABLE_JIT_PROFILING
- if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
- auto jmethod = iJIT_Method_Load();
- jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe
- jmethod.method_name = (char *)code_name; // XXX: dropping const
- jmethod.class_file_name = NULL;
- jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const
- jmethod.method_load_address = (void *)code;
- jmethod.method_size = (unsigned int)code_size;
-
- iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
- (void*)&jmethod);
- }
-#else
- UNUSED(code);
- UNUSED(code_size);
- UNUSED(code_name);
- UNUSED(source_file_name);
-#endif
-}
-
-void register_jit_code(const void *code, size_t code_size,
- const char *code_name, const char *source_file_name)
-{
- // The #ifdef guards are required to avoid generating a function that only
- // consists of lock and unlock code
-#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP
- static std::mutex m;
- std::lock_guard<std::mutex> guard(m);
-
- dump_jit_code(code, code_size, code_name);
- register_jit_code_vtune(code, code_size, code_name, source_file_name);
-#else
- UNUSED(code);
- UNUSED(code_size);
- UNUSED(code_name);
- UNUSED(source_file_name);
-#endif
-}
-
-}
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp
deleted file mode 100644
index 2f52dba4ac..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp
+++ /dev/null
@@ -1,32 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef JIT_SUPPORT_HPP
-#define JIT_SUPPORT_HPP
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-namespace jit_utils {
-
-void register_jit_code(const void *code, size_t code_size,
- const char *code_name, const char *source_file_name);
-
-}
-}
-}
-}
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD
deleted file mode 100644
index 4fd21cea57..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD
+++ /dev/null
@@ -1,27 +0,0 @@
-Copyright (c) 2011, Intel Corporation
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-1. Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-2. Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-3. Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived from
- this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md
deleted file mode 100644
index fc67c4f134..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md
+++ /dev/null
@@ -1 +0,0 @@
-This code is from [Intel SEAPI library](https://github.com/intel/IntelSEAPI)
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h
deleted file mode 100644
index edbf4a15f0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h
+++ /dev/null
@@ -1,595 +0,0 @@
-/* <copyright>
-
- Contact Information:
- http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
-
- BSD LICENSE
-
- Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
- All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- * Neither the name of Intel Corporation nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</copyright> */
-#ifndef _ITTNOTIFY_CONFIG_H_
-#define _ITTNOTIFY_CONFIG_H_
-
-/** @cond exclude_from_documentation */
-#ifndef ITT_OS_WIN
-# define ITT_OS_WIN 1
-#endif /* ITT_OS_WIN */
-
-#ifndef ITT_OS_LINUX
-# define ITT_OS_LINUX 2
-#endif /* ITT_OS_LINUX */
-
-#ifndef ITT_OS_MAC
-# define ITT_OS_MAC 3
-#endif /* ITT_OS_MAC */
-
-#ifndef ITT_OS_FREEBSD
-# define ITT_OS_FREEBSD 4
-#endif /* ITT_OS_FREEBSD */
-
-#ifndef ITT_OS
-# if defined WIN32 || defined _WIN32
-# define ITT_OS ITT_OS_WIN
-# elif defined( __APPLE__ ) && defined( __MACH__ )
-# define ITT_OS ITT_OS_MAC
-# elif defined( __FreeBSD__ )
-# define ITT_OS ITT_OS_FREEBSD
-# else
-# define ITT_OS ITT_OS_LINUX
-# endif
-#endif /* ITT_OS */
-
-#ifndef ITT_PLATFORM_WIN
-# define ITT_PLATFORM_WIN 1
-#endif /* ITT_PLATFORM_WIN */
-
-#ifndef ITT_PLATFORM_POSIX
-# define ITT_PLATFORM_POSIX 2
-#endif /* ITT_PLATFORM_POSIX */
-
-#ifndef ITT_PLATFORM_MAC
-# define ITT_PLATFORM_MAC 3
-#endif /* ITT_PLATFORM_MAC */
-
-#ifndef ITT_PLATFORM_FREEBSD
-# define ITT_PLATFORM_FREEBSD 4
-#endif /* ITT_PLATFORM_FREEBSD */
-
-#ifndef ITT_PLATFORM
-# if ITT_OS==ITT_OS_WIN
-# define ITT_PLATFORM ITT_PLATFORM_WIN
-# elif ITT_OS==ITT_OS_MAC
-# define ITT_PLATFORM ITT_PLATFORM_MAC
-# elif ITT_OS==ITT_OS_FREEBSD
-# define ITT_PLATFORM ITT_PLATFORM_FREEBSD
-# else
-# define ITT_PLATFORM ITT_PLATFORM_POSIX
-# endif
-#endif /* ITT_PLATFORM */
-
-#if defined(_UNICODE) && !defined(UNICODE)
-#define UNICODE
-#endif
-
-#include <stddef.h>
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#include <tchar.h>
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#include <stdint.h>
-#if defined(UNICODE) || defined(_UNICODE)
-#include <wchar.h>
-#endif /* UNICODE || _UNICODE */
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-#ifndef ITTAPI_CDECL
-# if ITT_PLATFORM==ITT_PLATFORM_WIN
-# define ITTAPI_CDECL __cdecl
-# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-# if defined _M_IX86 || defined __i386__
-# define ITTAPI_CDECL __attribute__ ((cdecl))
-# else /* _M_IX86 || __i386__ */
-# define ITTAPI_CDECL /* actual only on x86 platform */
-# endif /* _M_IX86 || __i386__ */
-# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* ITTAPI_CDECL */
-
-#ifndef STDCALL
-# if ITT_PLATFORM==ITT_PLATFORM_WIN
-# define STDCALL __stdcall
-# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-# if defined _M_IX86 || defined __i386__
-# define STDCALL __attribute__ ((stdcall))
-# else /* _M_IX86 || __i386__ */
-# define STDCALL /* supported only on x86 platform */
-# endif /* _M_IX86 || __i386__ */
-# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#endif /* STDCALL */
-
-#define ITTAPI ITTAPI_CDECL
-#define LIBITTAPI ITTAPI_CDECL
-
-/* TODO: Temporary for compatibility! */
-#define ITTAPI_CALL ITTAPI_CDECL
-#define LIBITTAPI_CALL ITTAPI_CDECL
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-/* use __forceinline (VC++ specific) */
-#define ITT_INLINE __forceinline
-#define ITT_INLINE_ATTRIBUTE /* nothing */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-/*
- * Generally, functions are not inlined unless optimization is specified.
- * For functions declared inline, this attribute inlines the function even
- * if no optimization level was specified.
- */
-#ifdef __STRICT_ANSI__
-#define ITT_INLINE static
-#define ITT_INLINE_ATTRIBUTE __attribute__((unused))
-#else /* __STRICT_ANSI__ */
-#define ITT_INLINE static inline
-#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused))
-#endif /* __STRICT_ANSI__ */
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-/** @endcond */
-
-#ifndef ITT_ARCH_IA32
-# define ITT_ARCH_IA32 1
-#endif /* ITT_ARCH_IA32 */
-
-#ifndef ITT_ARCH_IA32E
-# define ITT_ARCH_IA32E 2
-#endif /* ITT_ARCH_IA32E */
-
-#ifndef ITT_ARCH_ARM
-# define ITT_ARCH_ARM 4
-#endif /* ITT_ARCH_ARM */
-
-#ifndef ITT_ARCH_PPC64
-# define ITT_ARCH_PPC64 5
-#endif /* ITT_ARCH_PPC64 */
-
-#ifndef ITT_ARCH
-# if defined _M_IX86 || defined __i386__
-# define ITT_ARCH ITT_ARCH_IA32
-# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__
-# define ITT_ARCH ITT_ARCH_IA32E
-# elif defined _M_IA64 || defined __ia64__
-# define ITT_ARCH ITT_ARCH_IA64
-# elif defined _M_ARM || defined __arm__
-# define ITT_ARCH ITT_ARCH_ARM
-# elif defined __powerpc64__
-# define ITT_ARCH ITT_ARCH_PPC64
-# endif
-#endif
-
-#ifdef __cplusplus
-# define ITT_EXTERN_C extern "C"
-# define ITT_EXTERN_C_BEGIN extern "C" {
-# define ITT_EXTERN_C_END }
-#else
-# define ITT_EXTERN_C /* nothing */
-# define ITT_EXTERN_C_BEGIN /* nothing */
-# define ITT_EXTERN_C_END /* nothing */
-#endif /* __cplusplus */
-
-#define ITT_TO_STR_AUX(x) #x
-#define ITT_TO_STR(x) ITT_TO_STR_AUX(x)
-
-#define __ITT_BUILD_ASSERT(expr, suffix) do { \
- static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \
- __itt_build_check_##suffix[0] = 0; \
-} while(0)
-#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix)
-#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__)
-
-#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 }
-
-/* Replace with snapshot date YYYYMMDD for promotion build. */
-#define API_VERSION_BUILD 20151119
-
-#ifndef API_VERSION_NUM
-#define API_VERSION_NUM 0.0.0
-#endif /* API_VERSION_NUM */
-
-#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \
- " (" ITT_TO_STR(API_VERSION_BUILD) ")"
-
-/* OS communication functions */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#include <windows.h>
-typedef HMODULE lib_t;
-typedef DWORD TIDT;
-typedef CRITICAL_SECTION mutex_t;
-#define MUTEX_INITIALIZER { 0 }
-#define strong_alias(name, aliasname) /* empty for Windows */
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#include <dlfcn.h>
-#if defined(UNICODE) || defined(_UNICODE)
-#include <wchar.h>
-#endif /* UNICODE */
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */
-#endif /* _GNU_SOURCE */
-#ifndef __USE_UNIX98
-#define __USE_UNIX98 1 /* need for PTHREAD_MUTEX_RECURSIVE, on SLES11.1 with gcc 4.3.4 wherein pthread.h missing dependency on __USE_XOPEN2K8 */
-#endif /*__USE_UNIX98*/
-#include <pthread.h>
-typedef void* lib_t;
-typedef pthread_t TIDT;
-typedef pthread_mutex_t mutex_t;
-#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER
-#define _strong_alias(name, aliasname) \
- extern __typeof (name) aliasname __attribute__ ((alias (#name)));
-#define strong_alias(name, aliasname) _strong_alias(name, aliasname)
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define __itt_get_proc(lib, name) GetProcAddress(lib, name)
-#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex)
-#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex)
-#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex)
-#define __itt_load_lib(name) LoadLibraryA(name)
-#define __itt_unload_lib(handle) FreeLibrary(handle)
-#define __itt_system_error() (int)GetLastError()
-#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2)
-#define __itt_fstrnlen(s, l) strnlen_s(s, l)
-#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l)
-#define __itt_fstrdup(s) _strdup(s)
-#define __itt_thread_id() GetCurrentThreadId()
-#define __itt_thread_yield() SwitchToThread()
-#ifndef ITT_SIMPLE_INIT
-ITT_INLINE long
-__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE;
-ITT_INLINE long __itt_interlocked_increment(volatile long* ptr)
-{
- return InterlockedIncrement(ptr);
-}
-#endif /* ITT_SIMPLE_INIT */
-
-#define DL_SYMBOLS (1)
-#define PTHREAD_SYMBOLS (1)
-
-#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */
-#define __itt_get_proc(lib, name) dlsym(lib, name)
-#define __itt_mutex_init(mutex) {\
- pthread_mutexattr_t mutex_attr; \
- int error_code = pthread_mutexattr_init(&mutex_attr); \
- if (error_code) \
- __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \
- error_code); \
- error_code = pthread_mutexattr_settype(&mutex_attr, \
- PTHREAD_MUTEX_RECURSIVE); \
- if (error_code) \
- __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \
- error_code); \
- error_code = pthread_mutex_init(mutex, &mutex_attr); \
- if (error_code) \
- __itt_report_error(__itt_error_system, "pthread_mutex_init", \
- error_code); \
- error_code = pthread_mutexattr_destroy(&mutex_attr); \
- if (error_code) \
- __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \
- error_code); \
-}
-#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex)
-#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex)
-#define __itt_load_lib(name) dlopen(name, RTLD_LAZY)
-#define __itt_unload_lib(handle) dlclose(handle)
-#define __itt_system_error() errno
-#define __itt_fstrcmp(s1, s2) strcmp(s1, s2)
-
-/* makes customer code define safe APIs for SDL_STRNLEN_S and SDL_STRNCPY_S */
-#ifdef SDL_STRNLEN_S
-#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l)
-#else
-#define __itt_fstrnlen(s, l) strlen(s)
-#endif /* SDL_STRNLEN_S */
-#ifdef SDL_STRNCPY_S
-#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l)
-#else
-#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l)
-#endif /* SDL_STRNCPY_S */
-
-#define __itt_fstrdup(s) strdup(s)
-#define __itt_thread_id() pthread_self()
-#define __itt_thread_yield() sched_yield()
-#if ITT_ARCH==ITT_ARCH_IA64
-#ifdef __INTEL_COMPILER
-#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val)
-#else /* __INTEL_COMPILER */
-/* TODO: Add Support for not Intel compilers for IA-64 architecture */
-#endif /* __INTEL_COMPILER */
-#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */
-ITT_INLINE long
-__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE;
-ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend)
-{
- long result;
- __asm__ __volatile__("lock\nxadd %0,%1"
- : "=r"(result),"=m"(*(int*)ptr)
- : "0"(addend), "m"(*(int*)ptr)
- : "memory");
- return result;
-}
-#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64
-#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val)
-#endif /* ITT_ARCH==ITT_ARCH_IA64 */
-#ifndef ITT_SIMPLE_INIT
-ITT_INLINE long
-__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE;
-ITT_INLINE long __itt_interlocked_increment(volatile long* ptr)
-{
- return __TBB_machine_fetchadd4(ptr, 1) + 1L;
-}
-#endif /* ITT_SIMPLE_INIT */
-
-void* dlopen(const char*, int) __attribute__((weak));
-void* dlsym(void*, const char*) __attribute__((weak));
-int dlclose(void*) __attribute__((weak));
-#define DL_SYMBOLS (dlopen && dlsym && dlclose)
-
-int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak));
-int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak));
-int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak));
-int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak));
-int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak));
-int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak));
-int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak));
-pthread_t pthread_self(void) __attribute__((weak));
-#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self)
-
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-typedef enum {
- __itt_collection_normal = 0,
- __itt_collection_paused = 1
-} __itt_collection_state;
-
-typedef enum {
- __itt_thread_normal = 0,
- __itt_thread_ignored = 1
-} __itt_thread_state;
-
-#pragma pack(push, 8)
-
-typedef struct ___itt_thread_info
-{
- const char* nameA; /*!< Copy of original name in ASCII. */
-#if defined(UNICODE) || defined(_UNICODE)
- const wchar_t* nameW; /*!< Copy of original name in UNICODE. */
-#else /* UNICODE || _UNICODE */
- void* nameW;
-#endif /* UNICODE || _UNICODE */
- TIDT tid;
- __itt_thread_state state; /*!< Thread state (paused or normal) */
- int extra1; /*!< Reserved to the runtime */
- void* extra2; /*!< Reserved to the runtime */
- struct ___itt_thread_info* next;
-} __itt_thread_info;
-
-#include "ittnotify_types.h" /* For __itt_group_id definition */
-
-typedef struct ___itt_api_info_20101001
-{
- const char* name;
- void** func_ptr;
- void* init_func;
- __itt_group_id group;
-} __itt_api_info_20101001;
-
-typedef struct ___itt_api_info
-{
- const char* name;
- void** func_ptr;
- void* init_func;
- void* null_func;
- __itt_group_id group;
-} __itt_api_info;
-
-typedef struct __itt_counter_info
-{
- const char* nameA; /*!< Copy of original name in ASCII. */
-#if defined(UNICODE) || defined(_UNICODE)
- const wchar_t* nameW; /*!< Copy of original name in UNICODE. */
-#else /* UNICODE || _UNICODE */
- void* nameW;
-#endif /* UNICODE || _UNICODE */
- const char* domainA; /*!< Copy of original name in ASCII. */
-#if defined(UNICODE) || defined(_UNICODE)
- const wchar_t* domainW; /*!< Copy of original name in UNICODE. */
-#else /* UNICODE || _UNICODE */
- void* domainW;
-#endif /* UNICODE || _UNICODE */
- int type;
- long index;
- int extra1; /*!< Reserved to the runtime */
- void* extra2; /*!< Reserved to the runtime */
- struct __itt_counter_info* next;
-} __itt_counter_info_t;
-
-struct ___itt_domain;
-struct ___itt_string_handle;
-
-typedef struct ___itt_global
-{
- unsigned char magic[8];
- unsigned long version_major;
- unsigned long version_minor;
- unsigned long version_build;
- volatile long api_initialized;
- volatile long mutex_initialized;
- volatile long atomic_counter;
- mutex_t mutex;
- lib_t lib;
- void* error_handler;
- const char** dll_path_ptr;
- __itt_api_info* api_list_ptr;
- struct ___itt_global* next;
- /* Joinable structures below */
- __itt_thread_info* thread_list;
- struct ___itt_domain* domain_list;
- struct ___itt_string_handle* string_list;
- __itt_collection_state state;
- __itt_counter_info_t* counter_list;
-} __itt_global;
-
-#pragma pack(pop)
-
-#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \
- h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \
- if (h != NULL) { \
- h->tid = t; \
- h->nameA = NULL; \
- h->nameW = n ? _wcsdup(n) : NULL; \
- h->state = s; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->thread_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \
- h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \
- if (h != NULL) { \
- h->tid = t; \
- h->nameA = n ? __itt_fstrdup(n) : NULL; \
- h->nameW = NULL; \
- h->state = s; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->thread_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \
- h = (__itt_domain*)malloc(sizeof(__itt_domain)); \
- if (h != NULL) { \
- h->flags = 1; /* domain is enabled by default */ \
- h->nameA = NULL; \
- h->nameW = name ? _wcsdup(name) : NULL; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->domain_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \
- h = (__itt_domain*)malloc(sizeof(__itt_domain)); \
- if (h != NULL) { \
- h->flags = 1; /* domain is enabled by default */ \
- h->nameA = name ? __itt_fstrdup(name) : NULL; \
- h->nameW = NULL; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->domain_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \
- h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \
- if (h != NULL) { \
- h->strA = NULL; \
- h->strW = name ? _wcsdup(name) : NULL; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->string_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \
- h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \
- if (h != NULL) { \
- h->strA = name ? __itt_fstrdup(name) : NULL; \
- h->strW = NULL; \
- h->extra1 = 0; /* reserved */ \
- h->extra2 = NULL; /* reserved */ \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->string_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \
- h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \
- if (h != NULL) { \
- h->nameA = NULL; \
- h->nameW = name ? _wcsdup(name) : NULL; \
- h->domainA = NULL; \
- h->domainW = name ? _wcsdup(domain) : NULL; \
- h->type = type; \
- h->index = 0; \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->counter_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \
- h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \
- if (h != NULL) { \
- h->nameA = name ? __itt_fstrdup(name) : NULL; \
- h->nameW = NULL; \
- h->domainA = domain ? __itt_fstrdup(domain) : NULL; \
- h->domainW = NULL; \
- h->type = type; \
- h->index = 0; \
- h->next = NULL; \
- if (h_tail == NULL) \
- (gptr)->counter_list = h; \
- else \
- h_tail->next = h; \
- } \
-}
-
-#endif /* _ITTNOTIFY_CONFIG_H_ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h
deleted file mode 100644
index 99fbc24054..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h
+++ /dev/null
@@ -1,94 +0,0 @@
-/* <copyright>
-
- Contact Information:
- http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
-
- BSD LICENSE
-
- Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
- All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- * Neither the name of Intel Corporation nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</copyright> */
-
-#ifndef _ITTNOTIFY_TYPES_H_
-#define _ITTNOTIFY_TYPES_H_
-
-typedef enum ___itt_group_id
-{
- __itt_group_none = 0,
- __itt_group_legacy = 1<<0,
- __itt_group_control = 1<<1,
- __itt_group_thread = 1<<2,
- __itt_group_mark = 1<<3,
- __itt_group_sync = 1<<4,
- __itt_group_fsync = 1<<5,
- __itt_group_jit = 1<<6,
- __itt_group_model = 1<<7,
- __itt_group_splitter_min = 1<<7,
- __itt_group_counter = 1<<8,
- __itt_group_frame = 1<<9,
- __itt_group_stitch = 1<<10,
- __itt_group_heap = 1<<11,
- __itt_group_splitter_max = 1<<12,
- __itt_group_structure = 1<<12,
- __itt_group_suppress = 1<<13,
- __itt_group_arrays = 1<<14,
- __itt_group_all = -1
-} __itt_group_id;
-
-#pragma pack(push, 8)
-
-typedef struct ___itt_group_list
-{
- __itt_group_id id;
- const char* name;
-} __itt_group_list;
-
-#pragma pack(pop)
-
-#define ITT_GROUP_LIST(varname) \
- static __itt_group_list varname[] = { \
- { __itt_group_all, "all" }, \
- { __itt_group_control, "control" }, \
- { __itt_group_thread, "thread" }, \
- { __itt_group_mark, "mark" }, \
- { __itt_group_sync, "sync" }, \
- { __itt_group_fsync, "fsync" }, \
- { __itt_group_jit, "jit" }, \
- { __itt_group_model, "model" }, \
- { __itt_group_counter, "counter" }, \
- { __itt_group_frame, "frame" }, \
- { __itt_group_stitch, "stitch" }, \
- { __itt_group_heap, "heap" }, \
- { __itt_group_structure, "structure" }, \
- { __itt_group_suppress, "suppress" }, \
- { __itt_group_arrays, "arrays" }, \
- { __itt_group_none, NULL } \
- }
-
-#endif /* _ITTNOTIFY_TYPES_H_ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c
deleted file mode 100644
index 15f4b9929b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c
+++ /dev/null
@@ -1,293 +0,0 @@
-/* <copyright>
-
- Contact Information:
- http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
-
- BSD LICENSE
-
- Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
- All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- * Neither the name of Intel Corporation nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</copyright> */
-
-#include "ittnotify_config.h"
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#include <windows.h>
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD
-#include <malloc.h>
-#endif
-#include <stdlib.h>
-
-#include "jitprofiling.h"
-
-static const char rcsid[] = "\n@(#) $Revision: 471937 $\n";
-
-#define DLL_ENVIRONMENT_VAR "VS_PROFILER"
-
-#ifndef NEW_DLL_ENVIRONMENT_VAR
-#if ITT_ARCH==ITT_ARCH_IA32
-#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32"
-#else
-#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64"
-#endif
-#endif /* NEW_DLL_ENVIRONMENT_VAR */
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
-#define DEFAULT_DLLNAME "JitPI.dll"
-HINSTANCE m_libHandle = NULL;
-#elif ITT_PLATFORM==ITT_PLATFORM_MAC
-#define DEFAULT_DLLNAME "libJitPI.dylib"
-void* m_libHandle = NULL;
-#else
-#define DEFAULT_DLLNAME "libJitPI.so"
-void* m_libHandle = NULL;
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
-/* default location of JIT profiling agent on Android */
-#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so"
-
-/* the function pointers */
-typedef unsigned int(JITAPI *TPInitialize)(void);
-static TPInitialize FUNC_Initialize=NULL;
-
-typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*);
-static TPNotify FUNC_NotifyEvent=NULL;
-
-static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING;
-
-/* end collector dll part. */
-
-/* loadiJIT_Funcs() : this function is called just in the beginning
- * and is responsible for loading the functions from BistroJavaCollector.dll
- * result:
- * on success: the functions load, iJIT_DLL_is_missing=0, return value = 1
- * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0
- */
-static int loadiJIT_Funcs(void);
-
-/* global representing whether the collector can't be loaded */
-static int iJIT_DLL_is_missing = 0;
-
-ITT_EXTERN_C int JITAPI
-iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData)
-{
- int ReturnValue = 0;
-
- /* initialization part - the collector has not been loaded yet. */
- if (!FUNC_NotifyEvent)
- {
- if (iJIT_DLL_is_missing)
- return 0;
-
- if (!loadiJIT_Funcs())
- return 0;
- }
-
- if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED ||
- event_type == iJVM_EVENT_TYPE_METHOD_UPDATE)
- {
- if (((piJIT_Method_Load)EventSpecificData)->method_id == 0)
- return 0;
- }
- else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2)
- {
- if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0)
- return 0;
- }
- else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3)
- {
- if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0)
- return 0;
- }
- else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED)
- {
- if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 ||
- ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0)
- return 0;
- }
-
- ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData);
-
- return ReturnValue;
-}
-
-ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive()
-{
- if (!iJIT_DLL_is_missing)
- {
- loadiJIT_Funcs();
- }
-
- return executionMode;
-}
-
-/* This function loads the collector dll and the relevant functions.
- * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1
- * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0
- */
-static int loadiJIT_Funcs()
-{
- static int bDllWasLoaded = 0;
- char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- DWORD dNameLength = 0;
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
- if(bDllWasLoaded)
- {
- /* dll was already loaded, no need to do it for the second time */
- return 1;
- }
-
- /* Assumes that the DLL will not be found */
- iJIT_DLL_is_missing = 1;
- FUNC_NotifyEvent = NULL;
-
- if (m_libHandle)
- {
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- FreeLibrary(m_libHandle);
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- dlclose(m_libHandle);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- m_libHandle = NULL;
- }
-
- /* Try to get the dll name from the environment */
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0);
- if (dNameLength)
- {
- DWORD envret = 0;
- dllName = (char*)malloc(sizeof(char) * (dNameLength + 1));
- if(dllName != NULL)
- {
- envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR,
- dllName, dNameLength);
- if (envret)
- {
- /* Try to load the dll from the PATH... */
- m_libHandle = LoadLibraryExA(dllName,
- NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
- }
- free(dllName);
- }
- } else {
- /* Try to use old VS_PROFILER variable */
- dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0);
- if (dNameLength)
- {
- DWORD envret = 0;
- dllName = (char*)malloc(sizeof(char) * (dNameLength + 1));
- if(dllName != NULL)
- {
- envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR,
- dllName, dNameLength);
- if (envret)
- {
- /* Try to load the dll from the PATH... */
- m_libHandle = LoadLibraryA(dllName);
- }
- free(dllName);
- }
- }
- }
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- dllName = getenv(NEW_DLL_ENVIRONMENT_VAR);
- if (!dllName)
- dllName = getenv(DLL_ENVIRONMENT_VAR);
-#if defined(__ANDROID__) || defined(ANDROID)
- if (!dllName)
- dllName = ANDROID_JIT_AGENT_PATH;
-#endif
- if (dllName)
- {
- /* Try to load the dll from the PATH... */
- m_libHandle = dlopen(dllName, RTLD_LAZY);
- }
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
-
- if (!m_libHandle)
- {
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- m_libHandle = LoadLibraryA(DEFAULT_DLLNAME);
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY);
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- }
-
- /* if the dll wasn't loaded - exit. */
- if (!m_libHandle)
- {
- iJIT_DLL_is_missing = 1; /* don't try to initialize
- * JIT agent the second time
- */
- return 0;
- }
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent");
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent");
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- if (!FUNC_NotifyEvent)
- {
- FUNC_Initialize = NULL;
- return 0;
- }
-
-#if ITT_PLATFORM==ITT_PLATFORM_WIN
- FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize");
-#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize");
-#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
- if (!FUNC_Initialize)
- {
- FUNC_NotifyEvent = NULL;
- return 0;
- }
-
- executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize();
-
- bDllWasLoaded = 1;
- iJIT_DLL_is_missing = 0; /* DLL is ok. */
-
- return 1;
-}
-
-ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID()
-{
- static unsigned int methodID = 1;
-
- if (methodID == 0)
- return 0; /* ERROR : this is not a valid value */
-
- return methodID++;
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h
deleted file mode 100644
index bf0489b1a1..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h
+++ /dev/null
@@ -1,673 +0,0 @@
-/* <copyright>
-
- Contact Information:
- http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
-
- BSD LICENSE
-
- Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
- All rights reserved.
-
- Redistribution and use in source and binary forms, with or without
- modification, are permitted provided that the following conditions
- are met:
-
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in
- the documentation and/or other materials provided with the
- distribution.
- * Neither the name of Intel Corporation nor the names of its
- contributors may be used to endorse or promote products derived
- from this software without specific prior written permission.
-
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
- A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
- OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
- SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
- LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
- DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
- THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
- (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-</copyright> */
-
-#ifndef __JITPROFILING_H__
-#define __JITPROFILING_H__
-
-/**
- * @brief JIT Profiling APIs
- *
- * The JIT Profiling API is used to report information about just-in-time
- * generated code that can be used by performance tools. The user inserts
- * calls in the code generator to report information before JIT-compiled
- * code goes to execution. This information is collected at runtime and used
- * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics
- * associated with JIT-compiled code.
- *
- * These APIs can be used to\n
- * - **Profile trace-based and method-based JIT-compiled
- * code**. Some examples of environments that you can profile with these APIs:
- * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM)
- * software technology, Java/.NET managed execution environments, and custom
- * ISV JIT engines.
- * @code
- * #include <jitprofiling.h>
- *
- * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) {
- * return;
- * }
- *
- * iJIT_Method_Load jmethod = {0};
- * jmethod.method_id = iJIT_GetNewMethodID();
- * jmethod.method_name = "method_name";
- * jmethod.class_file_name = "class_name";
- * jmethod.source_file_name = "source_file_name";
- * jmethod.method_load_address = code_addr;
- * jmethod.method_size = code_size;
- *
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod);
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL);
- * @endcode
- *
- * * Expected behavior:
- * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an
- * already reported method, then such a method becomes invalid and its
- * memory region is treated as unloaded. VTune Amplifier displays the metrics
- * collected by the method until it is overwritten.
- * * If supplied line number information contains multiple source lines for
- * the same assembly instruction (code location), then VTune Amplifier picks up
- * the first line number.
- * * Dynamically generated code can be associated with a module name.
- * Use the iJIT_Method_Load_V2 structure.\n
- * Clarification of some cases:
- * * If you register a function with the same method ID multiple times,
- * specifying different module names, then the VTune Amplifier picks up
- * the module name registered first. If you want to distinguish the same
- * function between different JIT engines, supply different method IDs for
- * each function. Other symbolic information (for example, source file)
- * can be identical.
- *
- * - **Analyze split functions** (multiple joint or disjoint code regions
- * belonging to the same function) **including re-JIT**
- * with potential overlapping of code regions in time, which is common in
- * resource-limited environments.
- * @code
- * #include <jitprofiling.h>
- *
- * unsigned int method_id = iJIT_GetNewMethodID();
- *
- * iJIT_Method_Load a = {0};
- * a.method_id = method_id;
- * a.method_load_address = 0x100;
- * a.method_size = 0x20;
- *
- * iJIT_Method_Load b = {0};
- * b.method_id = method_id;
- * b.method_load_address = 0x200;
- * b.method_size = 0x30;
- *
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a);
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b);
- * @endcode
- *
- * * Expected behaviour:
- * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an
- * already reported method, then such a method becomes invalid and
- * its memory region is treated as unloaded.
- * * All code regions reported with the same method ID are considered as
- * belonging to the same method. Symbolic information (method name,
- * source file name) will be taken from the first notification, and all
- * subsequent notifications with the same method ID will be processed
- * only for line number table information. So, the VTune Amplifier will map
- * samples to a source line using the line number table from the current
- * notification while taking the source file name from the very first one.\n
- * Clarification of some cases:\n
- * * If you register a second code region with a different source file
- * name and the same method ID, then this information will be saved and
- * will not be considered as an extension of the first code region, but
- * VTune Amplifier will use the source file of the first code region and map
- * performance metrics incorrectly.
- * * If you register a second code region with the same source file as
- * for the first region and the same method ID, then the source file will be
- * discarded but VTune Amplifier will map metrics to the source file correctly.
- * * If you register a second code region with a null source file and
- * the same method ID, then provided line number info will be associated
- * with the source file of the first code region.
- *
- * - **Explore inline functions** including multi-level hierarchy of
- * nested inline methods which shows how performance metrics are distributed through them.
- * @code
- * #include <jitprofiling.h>
- *
- * // method_id parent_id
- * // [-- c --] 3000 2000
- * // [---- d -----] 2001 1000
- * // [---- b ----] 2000 1000
- * // [------------ a ----------------] 1000 n/a
- *
- * iJIT_Method_Load a = {0};
- * a.method_id = 1000;
- *
- * iJIT_Method_Inline_Load b = {0};
- * b.method_id = 2000;
- * b.parent_method_id = 1000;
- *
- * iJIT_Method_Inline_Load c = {0};
- * c.method_id = 3000;
- * c.parent_method_id = 2000;
- *
- * iJIT_Method_Inline_Load d = {0};
- * d.method_id = 2001;
- * d.parent_method_id = 1000;
- *
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a);
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b);
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c);
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d);
- * @endcode
- *
- * * Requirements:
- * * Each inline (iJIT_Method_Inline_Load) method should be associated
- * with two method IDs: one for itself; one for its immediate parent.
- * * Address regions of inline methods of the same parent method cannot
- * overlap each other.
- * * Execution of the parent method must not be started until it and all
- * its inline methods are reported.
- * * Expected behaviour:
- * * In case of nested inline methods an order of
- * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important.
- * * If any event overwrites either inline method or top parent method,
- * then the parent, including inline methods, becomes invalid and its memory
- * region is treated as unloaded.
- *
- * **Life time of allocated data**\n
- * The client sends an event notification to the agent with event-specific
- * data, which is a structure. The pointers in the structure refer to memory
- * allocated by the client, which is responsible for releasing it. The pointers are
- * used by the iJIT_NotifyEvent method to copy the client's data into a trace file,
- * and they are not used after the iJIT_NotifyEvent method returns.
- */
-
-/**
- * @defgroup jitapi JIT Profiling
- * @ingroup internal
- * @{
- */
-
-/**
- * @brief Enumerator for the types of notifications
- */
-typedef enum iJIT_jvm_event
-{
- iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent.
- * Use NULL for event data. */
-
- iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is
- * JIT compiled and loaded into
- * memory by the JIT engine, but
- * before the code is executed.
- * Use iJIT_Method_Load as event
- * data. */
-/** @cond exclude_from_documentation */
- iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic
- * code is being unloaded from memory.
- * Use iJIT_Method_Load as event data.*/
-/** @endcond */
-
- iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for
- * a previously reported dynamic code.
- * The previous content will be invalidated
- * starting from the time of the notification.
- * Use iJIT_Method_Load as event data but
- * required fields are following:
- * - method_id identify the code to update.
- * - method_load_address specify start address
- * within identified code range
- * where update should be started.
- * - method_size specify length of updated code
- * range. */
-
-
- iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic
- * code is JIT compiled and loaded
- * into memory by the JIT engine,
- * but before the parent code region
- * starts executing.
- * Use iJIT_Method_Inline_Load as event data.*/
-
-/** @cond exclude_from_documentation */
- iJVM_EVENT_TYPE_METHOD_UPDATE_V2,
-/** @endcond */
-
- iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is
- * JIT compiled and loaded into
- * memory by the JIT engine, but
- * before the code is executed.
- * Use iJIT_Method_Load_V2 as event data. */
-
- iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is
- * JIT compiled and loaded into
- * memory by the JIT engine, but
- * before the code is executed.
- * Use iJIT_Method_Load_V3 as event data. */
-} iJIT_JVM_EVENT;
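-
-/**
- * A minimal sketch of the iJVM_EVENT_TYPE_METHOD_UPDATE required fields listed
- * above (not part of the original header; patch_addr and patch_size are
- * placeholders for a sub-range of a code region previously reported with
- * method ID 1000):
- * @code
- * iJIT_Method_Load u = {0};
- * u.method_id           = 1000;        // code being updated
- * u.method_load_address = patch_addr;  // start of the updated sub-range
- * u.method_size         = patch_size;  // length of the updated sub-range
- * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_UPDATE, (void*)&u);
- * @endcode
- */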
-
-/**
- * @brief Enumerator for the agent's mode
- */
-typedef enum _iJIT_IsProfilingActiveFlags
-{
- iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running;
- * iJIT_NotifyEvent calls will
- * not be processed. */
- iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and
- * ready to process notifications. */
-} iJIT_IsProfilingActiveFlags;
-
-/**
- * @brief Description of a single entry in the line number information of a code region.
- * @details A table of line number entries gives information about how the reported code region
- * is mapped to the source file.
- * Intel(R) VTune(TM) Amplifier uses line number information to attribute
- * samples (by virtual address) to source line numbers. \n
- * It is acceptable to report different code addresses for the same source line:
- * @code
- * Offset LineNumber
- * 1 2
- * 12 4
- * 15 2
- * 18 1
- * 21 30
- *
- * VTune Amplifier constructs the following table using the client data
- *
- * Code subrange Line number
- * 0-1 2
- * 1-12 4
- * 12-15 2
- * 15-18 1
- * 18-21 30
- * @endcode
- */
-typedef struct _LineNumberInfo
-{
-    unsigned int Offset;     /**<\brief Offset from the beginning of the code region. */
- unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */
-
-} *pLineNumberInfo, LineNumberInfo;
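-
-/**
- * A minimal sketch of filling the table for the mapping documented above
- * (not part of the original header; the values are taken from that example):
- * @code
- * LineNumberInfo table[5] = {
- *     { 1, 2 }, { 12, 4 }, { 15, 2 }, { 18, 1 }, { 21, 30 }
- * };
- * // Assign 'table' to line_number_table and 5 to line_number_size of the
- * // iJIT_Method_Load structure describing the code region.
- * @endcode
- */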
-
-/**
- * @brief Enumerator for the code architecture.
- */
-typedef enum _iJIT_CodeArchitecture
-{
- iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */
-
- iJIT_CA_32, /**<\brief 32-bit machine code. */
-
- iJIT_CA_64 /**<\brief 64-bit machine code. */
-
-} iJIT_CodeArchitecture;
-
-#pragma pack(push, 8)
-
-/**
- * @brief Description of a JIT-compiled method
- * @details When you use the iJIT_Method_Load structure to describe
- * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED
- * as an event type to report it.
- */
-typedef struct _iJIT_Method_Load
-{
- unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
- * You must either use the API function
- * iJIT_GetNewMethodID to get a valid and unique
- * method ID, or else manage ID uniqueness
- * and correct range by yourself.\n
- * You must use the same method ID for all code
- * regions of the same method, otherwise different
- * method IDs specify different methods. */
-
- char* method_name; /**<\brief The name of the method. It can be optionally
- * prefixed with its class name and appended with
- * its complete signature. Can't be NULL. */
-
- void* method_load_address; /**<\brief The start virtual address of the method code
- * region. If NULL, data provided with
- * event are not accepted. */
-
- unsigned int method_size; /**<\brief The code size of the method in memory.
- * If 0, then data provided with the event are not
- * accepted. */
-
-    unsigned int line_number_size;  /**<\brief The number of entries in the line number
-                                     * table. 0 if none. */
-
- pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
- * array. Can be NULL if
- * line_number_size is 0. See
- * LineNumberInfo Structure for a
- * description of a single entry in
- * the line number info array */
-
- unsigned int class_id; /**<\brief This field is obsolete. */
-
- char* class_file_name; /**<\brief Class name. Can be NULL.*/
-
- char* source_file_name; /**<\brief Source file name. Can be NULL.*/
-
-} *piJIT_Method_Load, iJIT_Method_Load;
-
-/**
- * @brief Description of a JIT-compiled method
- * @details When you use the iJIT_Method_Load_V2 structure to describe
- * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2
- * as an event type to report it.
- */
-typedef struct _iJIT_Method_Load_V2
-{
- unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
- * You must either use the API function
- * iJIT_GetNewMethodID to get a valid and unique
- * method ID, or else manage ID uniqueness
- * and correct range by yourself.\n
- * You must use the same method ID for all code
- * regions of the same method, otherwise different
- * method IDs specify different methods. */
-
- char* method_name; /**<\brief The name of the method. It can be optionally
- * prefixed with its class name and appended with
- * its complete signature. Can't be NULL. */
-
- void* method_load_address; /**<\brief The start virtual address of the method code
- * region. If NULL, then data provided with the
- * event are not accepted. */
-
- unsigned int method_size; /**<\brief The code size of the method in memory.
- * If 0, then data provided with the event are not
- * accepted. */
-
- unsigned int line_number_size; /**<\brief The number of entries in the line number
- * table. 0 if none. */
-
- pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
- * array. Can be NULL if
- * line_number_size is 0. See
- * LineNumberInfo Structure for a
- * description of a single entry in
- * the line number info array. */
-
- char* class_file_name; /**<\brief Class name. Can be NULL. */
-
- char* source_file_name; /**<\brief Source file name. Can be NULL. */
-
- char* module_name; /**<\brief Module name. Can be NULL.
- The module name can be useful for distinguishing among
- different JIT engines. VTune Amplifier will display
- reported methods grouped by specific module. */
-
-} *piJIT_Method_Load_V2, iJIT_Method_Load_V2;
-
-/**
- * @brief Description of a JIT-compiled method
- * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2
- * with a newly introduced 'arch' field that specifies architecture of the code region.
- * When you use the iJIT_Method_Load_V3 structure to describe
- * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3
- * as an event type to report it.
- */
-typedef struct _iJIT_Method_Load_V3
-{
- unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
- * You must either use the API function
- * iJIT_GetNewMethodID to get a valid and unique
- * method ID, or manage ID uniqueness
- * and correct range by yourself.\n
- * You must use the same method ID for all code
- * regions of the same method, otherwise they are
- * treated as regions of different methods. */
-
- char* method_name; /**<\brief The name of the method. It can be optionally
- * prefixed with its class name and appended with
- * its complete signature. Cannot be NULL. */
-
- void* method_load_address; /**<\brief The start virtual address of the method code
- * region. If NULL, then data provided with the
- * event are not accepted. */
-
- unsigned int method_size; /**<\brief The code size of the method in memory.
- * If 0, then data provided with the event are not
- * accepted. */
-
- unsigned int line_number_size; /**<\brief The number of entries in the line number
- * table. 0 if none. */
-
- pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
- * array. Can be NULL if
- * line_number_size is 0. See
- * LineNumberInfo Structure for a
- * description of a single entry in
- * the line number info array. */
-
- char* class_file_name; /**<\brief Class name. Can be NULL. */
-
- char* source_file_name; /**<\brief Source file name. Can be NULL. */
-
- char* module_name; /**<\brief Module name. Can be NULL.
- * The module name can be useful for distinguishing among
- * different JIT engines. VTune Amplifier will display
- * reported methods grouped by specific module. */
-
- iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region.
- * By default, it is the same as the process
- * architecture that is calling it.
- * For example, you can use it if your 32-bit JIT
- * engine generates 64-bit code.
- *
- * If JIT engine reports both 32-bit and 64-bit types
- * of methods then VTune Amplifier splits the methods
- * with the same module name but with different
- * architectures in two different modules. VTune Amplifier
- * modifies the original name provided with a 64-bit method
- * version by ending it with '(64)' */
-
-} *piJIT_Method_Load_V3, iJIT_Method_Load_V3;
-
-/**
- * @brief Description of an inline JIT-compiled method
- * @details When you use the iJIT_Method_Inline_Load structure to describe
- * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED
- * as an event type to report it.
- */
-typedef struct _iJIT_Method_Inline_Load
-{
- unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
- * You must either use the API function
- * iJIT_GetNewMethodID to get a valid and unique
- * method ID, or else manage ID uniqueness
- * and correct range by yourself. */
-
- unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID.
- * Cannot be 0.
- * You must either use the API function
- * iJIT_GetNewMethodID to get a valid and unique
- * method ID, or else manage ID uniqueness
- * and correct range by yourself. */
-
- char* method_name; /**<\brief The name of the method. It can be optionally
- * prefixed with its class name and appended with
- * its complete signature. Can't be NULL. */
-
-    void* method_load_address;      /**<\brief The virtual address on which the method
- * is inlined. If NULL, then data provided with
- * the event are not accepted. */
-
- unsigned int method_size; /**<\brief The code size of the method in memory.
- * If 0, then data provided with the event are not
- * accepted. */
-
- unsigned int line_number_size; /**<\brief The number of entries in the line number
- * table. 0 if none. */
-
- pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
- * array. Can be NULL if
- * line_number_size is 0. See
- * LineNumberInfo Structure for a
- * description of a single entry in
- * the line number info array */
-
- char* class_file_name; /**<\brief Class name. Can be NULL.*/
-
- char* source_file_name; /**<\brief Source file name. Can be NULL.*/
-
-} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load;
-
-/** @cond exclude_from_documentation */
-/**
- * @brief Description of a segment type
- * @details Use the segment type to specify a type of data supplied
- * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to
- * a certain code trace.
- */
-typedef enum _iJIT_SegmentType
-{
- iJIT_CT_UNKNOWN = 0,
-
- iJIT_CT_CODE, /**<\brief Executable code. */
-
- iJIT_CT_DATA, /**<\brief Data (not executable code).
- * VTune Amplifier uses the format string
- * (see iJIT_Method_Update) to represent
- * this data in the VTune Amplifier GUI */
-
- iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace.
- * Can be used for the following
- * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events,
- * if the type of the previously reported segment
- * type is the same. */
- iJIT_CT_EOF
-} iJIT_SegmentType;
-
-/**
- * @brief Description of a dynamic update of the content within JIT-compiled method
- * @details The JIT engine may generate the methods that are updated at runtime
- * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update
- * structure to describe the update of the content within a JIT-compiled method,
- * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it.
- *
- * On the first Update event, VTune Amplifier copies the original code range reported by
- * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and
- * adds the modified range to the original method. For subsequent update events, VTune Amplifier
- * does the same, but uses the latest modified version of the code region for the update.
- * Eventually, the VTune Amplifier GUI displays multiple code ranges for the method reported by
- * the iJVM_EVENT_TYPE_METHOD_LOAD event.
- * Notes:
- * - Multiple update events with different types for the same trace are allowed,
- *   but they must be reported for the same code ranges.
- *   Example:
- * @code
- * [-- data---] Allowed
- * [-- code --] Allowed
- * [code] Ignored
- * [-- data---] Allowed
- * [-- code --] Allowed
- * [------------ trace ---------]
- * @endcode
- * - The types of previously reported events can be changed, but they must be reported
- *   for the same code ranges.
- *   Example:
- * @code
- * [-- data---] Allowed
- * [-- code --] Allowed
- * [-- data---] Allowed
- * [-- code --] Allowed
- * [------------ trace ---------]
- * @endcode
- */
-
-typedef struct _iJIT_Method_Update
-{
- void* load_address; /**<\brief Start address of the update within a method */
-
- unsigned int size; /**<\brief The update size */
-
- iJIT_SegmentType type; /**<\brief Type of the update */
-
- const char* data_format; /**<\brief C string that contains a format string
- * that follows the same specifications as format in printf.
- * The format string is used for iJIT_CT_CODE only
- * and cannot be NULL.
- * Format can be changed on the fly. */
-} *piJIT_Method_Update, iJIT_Method_Update;
-
-/** @endcond */
-
-#pragma pack(pop)
-
-/** @cond exclude_from_documentation */
-#ifdef __cplusplus
-extern "C" {
-#endif /* __cplusplus */
-
-#ifndef JITAPI_CDECL
-# if defined WIN32 || defined _WIN32
-# define JITAPI_CDECL __cdecl
-# else /* defined WIN32 || defined _WIN32 */
-# if defined _M_IX86 || defined __i386__
-# define JITAPI_CDECL __attribute__ ((cdecl))
-# else /* _M_IX86 || __i386__ */
-#    define JITAPI_CDECL /* not needed on the x86_64 platform */
-# endif /* _M_IX86 || __i386__ */
-# endif /* defined WIN32 || defined _WIN32 */
-#endif /* JITAPI_CDECL */
-
-#define JITAPI JITAPI_CDECL
-/** @endcond */
-
-/**
- * @brief Generates a new unique method ID.
- *
- * You must use this API to obtain unique and valid method IDs for methods or
- * traces reported to the agent if you don't have your own mechanism to generate
- * unique method IDs.
- *
- * @return a new unique method ID. When out of unique method IDs, this API
- * returns 0, which is not an accepted value.
- */
-unsigned int JITAPI iJIT_GetNewMethodID(void);
-
-/**
- * @brief Returns the current mode of the agent.
- *
- * @return iJIT_SAMPLING_ON, indicating that the agent is running, or
- *         iJIT_NOTHING_RUNNING if no agent is running.
- */
-iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void);
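-
-/**
- * A minimal sketch (not part of the original header) of using the mode check
- * to skip reporting when no agent is attached:
- * @code
- * if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
- *     // an agent is collecting; report methods via iJIT_NotifyEvent()
- * }
- * @endcode
- */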
-
-/**
- * @brief Reports information about JIT-compiled code to the agent.
- *
- * The reported information is used to attribute samples obtained from any
- * Intel(R) VTune(TM) Amplifier collector. This API needs to be called
- * after JIT compilation and before the first entry into the JIT-compiled
- * code.
- *
- * @param[in] event_type - type of the data sent to the agent
- * @param[in] EventSpecificData - pointer to event-specific data
- *
- * @returns 1 on success, otherwise 0.
- */
-int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData);
-
-#ifdef __cplusplus
-}
-#endif /* __cplusplus */
-/** @endcond */
-
-/** @} jitapi group */
-
-#endif /* __JITPROFILING_H__ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp
deleted file mode 100644
index ef4c42bacf..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp
+++ /dev/null
@@ -1,317 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-
-#include "nchw_pooling.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-void nchw_pooling_fwd_t<data_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
-
- const memory_desc_wrapper ws_d(pd()->workspace_md());
- const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- auto alg = pd()->desc()->alg_kind;
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
- if (ws) {
- assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
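-            // (Descriptive note, not in the original source.) The workspace is stored
-            // densely in NCDHW order, so the flat offset below is
-            // ((((mb*C + c)*OD + od)*OH + oh)*OW + ow).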
- size_t ws_offset
- = (size_t)OW * OH * OD * C * mb
- + (size_t)OW * OH * OD * c
- + (size_t)OW * OH * od
- + (size_t)OW * oh
- + (size_t)ow;
- if (ws_dt == data_type::u8) {
- assert(0 <= value && value <= 255);
- ws[ws_offset] = value;
- } else
- reinterpret_cast<int *>(ws)[ws_offset] = value;
- }
- };
-
- auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
- for (int kd = 0; kd < KD; ++kd) {
- for (int kh = 0; kh < KH; ++kh) {
- for (int kw = 0; kw < KW; ++kw) {
- const int id = od * SD - padF + kd;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
- if (id < 0 || id >= ID) continue;
- if (ih < 0 || ih >= IH) continue;
- if (iw < 0 || iw >= IW) continue;
-
- auto src_offset
- = (size_t)IW * IH * ID * C * mb
- + (size_t)IW * IH * ID * c
- + (size_t)IW * IH * id
- + (size_t)IW * ih
- + (size_t)iw;
- auto s = src[src_offset];
- if (s > d[0]) {
- d[0] = s;
- set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
- }
- }
- }
- }
- };
-
- auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
- auto id_start = apply_offset(od*SD, padF);
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto id_end = nstl::min(od*SD - padF + KD, ID);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH
- : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
-
- for (int id = id_start; id < id_end; ++id) {
- for (int ih = ih_start; ih < ih_end; ++ih) {
- for (int iw = iw_start; iw < iw_end; ++iw) {
- auto src_offset
- = (size_t)IW * IH * ID * C * mb
- + (size_t)IW * IH * ID * c
- + (size_t)IW * IH * id
- + (size_t)IW * ih
- + (size_t)iw;
- d[0] += src[src_offset];
- }
- }
- }
-
- d[0] = math::out_round<data_t>((float)d[0] / num_summands);
- };
-
-
- if (pd()->desc()->alg_kind == pooling_max) {
- parallel_nd(MB, C, OD, OH, OW,
- [&](int mb, int c, int od, int oh, int ow) {
- size_t dst_offset
- = (size_t)OW * OH * OD * C * mb
- + (size_t)OW * OH * OD * c
- + (size_t)OW * OH * od
- + (size_t)OW * oh
- + (size_t)ow;
- data_t *d = &dst[dst_offset];
- d[0] = nstl::numeric_limits<data_t>::lowest();
- set_ws(mb, c, od, oh, ow, 0);
- ker_max(d, mb, c, od, oh, ow);
- });
- } else {
- parallel_nd(MB, C, OD, OH, OW,
- [&](int mb, int c, int od, int oh, int ow) {
- size_t dst_offset
- = (size_t)OW * OH * OD * C * mb
- + (size_t)OW * OH * OD * c
- + (size_t)OW * OH * od
- + (size_t)OW * oh
- + (size_t)ow;
- data_t *d = &dst[dst_offset];
- d[0] = 0;
- ker_avg(d, mb, c, od, oh, ow);
- });
- }
-}
-
-template <impl::data_type_t data_type>
-void nchw_pooling_bwd_t<data_type>::execute_backward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
-
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper ws_d(pd()->workspace_md());
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
-
- auto alg = pd()->desc()->alg_kind;
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- auto ker_zero = [=](int mb, int c) {
- size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
- for (int id = 0; id < ID; ++id) {
- for (int ih = 0; ih < IH; ++ih) {
- for (int iw = 0; iw < IW; ++iw) {
- diff_src[diff_src_offset++] = 0;
- }
- }
- }
- };
-
- auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
- auto b_c = ws_d.blocking_desc().inner_nblks == 0
- ? 1 : ws_d.blocking_desc().inner_blks[0];
- auto ws_offset = is_3d
- ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
- : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
-
- const int index = ws_d.data_type() == data_type::u8
- ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
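-        // (Descriptive note, not in the original source.) 'index' was encoded by the
-        // forward pass as kd*KH*KW + kh*KW + kw, so it is decoded here in reverse
-        // order: kw first, then kh, then kd.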
- const int kw = index % KW;
- const int kh = (index / KW) % KH;
- const int kd = (index / KW) / KH;
-
- const int id = od * SD - padF + kd;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
-        // If the padding area could fit the whole kernel, the computed input
-        // position falls out of bounds. There is nothing to back-propagate in
-        // that case, since padding is virtual for pooling_max.
- if (id < 0 || id >= ID)
- return;
- if (ih < 0 || ih >= IH)
- return;
- if (iw < 0 || iw >= IW)
- return;
-
- size_t diff_src_offset =
- (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
- + (size_t)ih*IW + (size_t)iw;
- diff_src[diff_src_offset] += d[0];
- };
-
- auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
- auto id_start = apply_offset(od*SD, padF);
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto id_end = nstl::min(od*SD - padF + KD, ID);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- size_t num_summands = (alg == pooling_avg_include_padding)
- ? (size_t)KW*KH*KD
- : (size_t)(id_end - id_start)*(ih_end - ih_start)
- *(iw_end - iw_start);
-
- for (int id = id_start; id < id_end; ++id) {
- for (int ih = ih_start; ih < ih_end; ++ih) {
- for (int iw = iw_start; iw < iw_end; ++iw) {
- size_t diff_src_offset = (size_t)mb*C*ID*IH*IW
- + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
- + (size_t)ih*IW + (size_t)iw;
- diff_src[diff_src_offset] += d[0] / num_summands;
- }
- }
- }
- };
-
- if (pd()->desc()->alg_kind == pooling_max) {
- parallel_nd(MB, C, [&](int mb, int c) {
- size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
- + (size_t)c*OD*OH*OW;
- ker_zero(mb, c);
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- const data_t *d = &diff_dst[diff_dst_offset++];
- ker_max(d, mb, c, od, oh, ow);
- }
- }
- }
- });
- } else {
- parallel_nd(MB, C, [&](int mb, int c) {
- size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
- + (size_t)c*OD*OH*OW;
- ker_zero(mb, c);
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- const data_t *d = &diff_dst[diff_dst_offset++];
- ker_avg(d, mb, c, od, oh, ow);
- }
- }
- }
- });
- }
-}
-
-template struct nchw_pooling_fwd_t<data_type::f32>;
-template struct nchw_pooling_bwd_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp
deleted file mode 100644
index bbdd04f6b9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp
+++ /dev/null
@@ -1,147 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_NCHW_POOLING_HPP
-#define CPU_NCHW_POOLING_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_pooling_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-struct nchw_pooling_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_fwd_pd_t {
- using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t);
-
- status_t init() {
- const format_tag_t desired_fmt_tag =
- ndims() == 4 ? format_tag::nchw : format_tag::ncdhw;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
- alg_kind::pooling_avg_include_padding,
- alg_kind::pooling_avg_exclude_padding)
- && !has_zero_dim_memory()
- && utils::everyone_is(data_type, src_md()->data_type,
- dst_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
- && memory_desc_matches_tag(*dst_md(), desired_fmt_tag);
- if (!ok) return status::unimplemented;
-
- bool is_training = desc_.prop_kind == prop_kind::forward_training;
- if (desc()->alg_kind == alg_kind::pooling_max && is_training)
- init_default_ws();
-
- return status::success;
- }
- };
-
- nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct nchw_pooling_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_bwd_pd_t {
- using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t);
-
- status_t init() {
- const format_tag_t desired_fmt_tag =
- ndims() == 4 ? format_tag::nchw : format_tag::ncdhw;
-
- bool ok = true
- && set_default_params() == status::success
- && !is_fwd()
- && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
- alg_kind::pooling_avg_include_padding,
- alg_kind::pooling_avg_exclude_padding)
- && !has_zero_dim_memory()
- && utils::everyone_is(data_type,
- diff_dst_md()->data_type,
- diff_src_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
- && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag);
- if (!ok) return status::unimplemented;
-
- if (desc()->alg_kind == alg_kind::pooling_max) {
- bool ws_ok = true
- && hint_fwd_pd_
- && hint_fwd_pd_->workspace_md();
- if (!ws_ok)
- return status::unimplemented;
-
- const auto &ws_blk =
- hint_fwd_pd_->workspace_md()->format_desc.blocking;
- ws_ok = ws_ok
-                    && ws_blk.inner_nblks <= 1
- && IMPLICATION(ws_blk.inner_nblks == 1,
- ws_blk.inner_idxs[0] == 1);
- if (!ws_ok)
- return status::unimplemented;
-
- ws_md_ = *hint_fwd_pd_->workspace_md();
- }
-
- return status::success;
- }
- };
-
- nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp
deleted file mode 100644
index c0e93fefe4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp
+++ /dev/null
@@ -1,382 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu_batch_normalization_utils.hpp"
-#include "jit_generator.hpp"
-
-#include "ncsp_batch_normalization.hpp"
-
-// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases
-#if (defined __clang_major__) && (__clang_major__ >= 6)
-#define SAFE_TO_USE_OMP_SIMD 0
-#else
-#define SAFE_TO_USE_OMP_SIMD 1
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace memory_tracking::names;
-
-void ncsp_batch_normalization_fwd_t::execute_forward(
- const exec_ctx_t &ctx) const {
- const bool calculate_stats = !pd()->stats_is_src();
- const bool save_stats = pd()->is_training();
- const bool is_training = pd()->is_training();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
-
- auto scratchpad = this->scratchpad(ctx);
- auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
-
- data_t *mean, *variance;
- if (!calculate_stats) {
- mean = const_cast<data_t *>(
- CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
- variance = const_cast<data_t *>(
- CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
- } else {
- if (save_stats) {
- mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
- variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
- } else {
- mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
- variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
- }
- }
-
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- const float eps = pd()->desc()->batch_norm_epsilon;
- const bool use_scaleshift = pd()->use_scaleshift();
- const bool with_relu = pd()->with_relu_post_op();
- auto maybe_post_op
- = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
- const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
- dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
- dim_t N = pd()->MB();
- dim_t C = pd()->C();
-
- int nthr = mkldnn_get_max_threads();
- size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
- size_t data_size = N * C * SP * sizeof(data_t);
- bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
-
- parallel(0, [&](const int ithr, const int nthr) {
- int C_ithr = 0, C_nthr = 0;
- int N_ithr = 0, N_nthr = 0;
- int S_ithr = 0, S_nthr = 0;
-
- dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
- dim_t N_s = 0, N_e = 0;
- dim_t S_s = 0, S_e = 0;
-
- dim_t C_blks_per_iter = 1;
- int64_t iters = 1;
-
- if (do_blocking) {
- size_t working_set_size = N * SP * sizeof(data_t);
- bnorm_utils::cache_balance(
- working_set_size, C, C_blks_per_iter, iters);
- } else
- C_blks_per_iter = C;
- int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
- bool spatial_thr_allowed
- = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
- C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
- N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
- balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
- int SP_N_ithr = N_ithr * S_nthr + S_ithr;
- int SP_N_nthr = N_nthr * S_nthr;
- for (int64_t it = 0; it < iters; ++it) {
- if (it == iters - 1 && iters > 1) {
- // On the last iteration the access pattern to ws_reduce
- // might change (due to re-balance on C). So sync the
- // threads if they are not synced by the algorithm.
- if (SP_N_nthr == 1 && mkldnn_thr_syncable())
- mkldnn_thr_barrier();
-
- S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
- spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
- spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
- C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
- N_e, S_ithr, S_nthr, S_s, S_e);
- balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
- SP_N_ithr = N_ithr * S_nthr + S_ithr;
- SP_N_nthr = N_nthr * S_nthr;
- }
- size_t C_off = it * C_blks_per_iter;
- // On the last iteration the access pattern to ws_reduce
- // might change (due to re-balance on C). Since sync is not always
- // possible (in case of TBB) use different parts of ws for each
- // iteration if threads are not synced by the algorithm.
- size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off;
-
- if (calculate_stats) {
- data_t *mean_blk = mean + C_off;
- data_t *variance_blk = variance + C_off;
- for (dim_t c = C_blk_s; c < C_blk_e; c++) {
- size_t off = (c + C_off) * SP;
- data_t sum = 0;
- for (dim_t n = N_s; n < N_e; ++n)
- PRAGMA_OMP_SIMD(reduction(+ : sum))
- for (dim_t sp = S_s; sp < S_e; ++sp) {
- sum += src[off + n * C * SP + sp];
- }
- ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
- = sum;
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
-
- for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
- mean_blk[c] = 0.;
- for (dim_t n = 0; n < SP_N_nthr; n++)
- mean_blk[c] += ws_reduce[ws_iter_off
- + n * C_blks_per_iter + c];
- mean_blk[c] /= (N * SP);
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
-
- for (dim_t c = C_blk_s; c < C_blk_e; c++) {
- size_t off = c + C_off;
- data_t sum = 0.;
- for (dim_t n = N_s; n < N_e; ++n)
- PRAGMA_OMP_SIMD(reduction(+ : sum))
- for (dim_t sp = S_s; sp < S_e; ++sp) {
- data_t m = src[off * SP + n * C * SP + sp]
- - mean[off];
- sum += m * m;
- }
- ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
- = sum;
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
-
- for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
- variance_blk[c] = 0.;
- for (dim_t n = 0; n < SP_N_nthr; n++)
- variance_blk[c] += ws_reduce[ws_iter_off
- + n * C_blks_per_iter + c];
- variance_blk[c] /= (N * SP);
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
- }
-
- for (dim_t c = C_blk_s; c < C_blk_e; c++) {
- size_t off = c + C_off;
- data_t sqrt_variance
- = static_cast<data_t>(sqrtf(variance[off] + eps));
- data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance;
- data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
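-                // (Descriptive note, not in the original source.) With these factors the
-                // loop below computes dst = gamma * (x - mean) / sqrt(var + eps) + beta
-                // (gamma = 1 and beta = 0 when use_scaleshift is false); when ReLU is
-                // fused, negative results are clamped to zero.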
- for (dim_t n = N_s; n < N_e; ++n)
-#if SAFE_TO_USE_OMP_SIMD
- PRAGMA_OMP_SIMD()
-#endif
- for (dim_t sp = S_s; sp < S_e; ++sp) {
- size_t d_off = off * SP + n * C * SP + sp;
- data_t bn_res
- = sm * (src[d_off] - mean[off]) + sv;
- if (fuse_bn_relu) {
- if (bn_res <= 0) {
- bn_res = 0;
- if (is_training)
- ws[d_off] = 0;
- } else {
- if (is_training)
- ws[d_off] = 1;
- }
- }
- dst[d_off] = maybe_post_op(bn_res);
- }
- }
- }
- });
-}
-
-void ncsp_batch_normalization_bwd_t::execute_backward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
- auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
- auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
- auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
-
- auto scratchpad = this->scratchpad(ctx);
- auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
-
- if (diff_scaleshift == nullptr)
- diff_scaleshift = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
-
- const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
- dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
- dim_t C = pd()->C(), N = pd()->MB();
- const bool use_scaleshift = pd()->use_scaleshift();
- const float eps = pd()->desc()->batch_norm_epsilon;
- const bool calculate_diff_stats = !pd()->use_global_stats();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
-
- int nthr = mkldnn_get_max_threads();
- size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
- size_t data_size = N * C * SP * sizeof(data_t);
- bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
-
- parallel(0, [&](const int ithr, const int nthr) {
- int C_ithr = 0, C_nthr = 0;
- int N_ithr = 0, N_nthr = 0;
- int S_ithr = 0, S_nthr = 0;
-
- dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
- dim_t N_s = 0, N_e = 0;
- dim_t S_s = 0, S_e = 0;
-
- dim_t C_blks_per_iter = 1;
- int64_t iters = 1;
-
- if (do_blocking) {
- size_t working_set_size = 2 * N * SP * sizeof(data_t);
- bnorm_utils::cache_balance(
- working_set_size, C, C_blks_per_iter, iters);
- } else
- C_blks_per_iter = C;
- int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
- bool spatial_thr_allowed
- = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
- C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
- N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
- balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
- int SP_N_ithr = N_ithr * S_nthr + S_ithr;
- int SP_N_nthr = N_nthr * S_nthr;
-
- for (int64_t it = 0; it < iters; ++it) {
- if (it == iters - 1 && iters > 1) {
- // On the last iteration the access pattern to ws_reduce
- // might change (due to re-balance on C). So sync the
- // threads if they are not synced by the algorithm.
- if (SP_N_nthr == 1 && mkldnn_thr_syncable())
- mkldnn_thr_barrier();
-
- C_blk_s = C_blk_e = N_s = N_e = 0;
- spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
- spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
- C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
- N_e, S_ithr, S_nthr, S_s, S_e);
- balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
- SP_N_ithr = N_ithr * S_nthr + S_ithr;
- SP_N_nthr = N_nthr * S_nthr;
- }
- size_t C_off = it * C_blks_per_iter;
- // On the last iteration the access pattern to ws_reduce
- // might change (due to re-balance on C). Since sync is not always
- // possible (in case of TBB) use different parts of ws for each
- // iteration if threads are not synced by the algorithm.
- size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off;
-
- data_t *diff_gamma_blk = diff_scaleshift + C_off;
- data_t *diff_beta_blk = diff_scaleshift + C + C_off;
- for (dim_t c = C_blk_s; c < C_blk_e; c++) {
- size_t off = c + C_off;
- data_t diff_gamma = 0.0, diff_beta = 0.0;
- data_t v_mean = mean[off];
- for (dim_t n = N_s; n < N_e; ++n)
- PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
- for (dim_t sp = S_s; sp < S_e; ++sp) {
- const size_t d_off = off * SP + n * C * SP + sp;
- data_t dd;
- if (fuse_bn_relu)
- dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
- else
- dd = diff_dst[d_off];
- diff_gamma += (src[d_off] - v_mean) * dd;
- diff_beta += dd;
- }
- ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
- = diff_gamma;
- ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
- + SP_N_ithr * C_blks_per_iter + c] = diff_beta;
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
-
- for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
- data_t sqrt_variance = static_cast<data_t>(
- 1.0f / sqrtf(variance[c + C_off] + eps));
- diff_gamma_blk[c] = 0.;
- diff_beta_blk[c] = 0.;
- for (dim_t n = 0; n < SP_N_nthr; n++) {
- diff_gamma_blk[c] += ws_reduce[ws_iter_off
- + n * C_blks_per_iter + c];
- diff_beta_blk[c] += ws_reduce[ws_iter_off
- + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
- + c];
- }
- diff_gamma_blk[c] *= sqrt_variance;
- }
-
- if (SP_N_nthr > 1) mkldnn_thr_barrier();
-
- for (dim_t c = C_blk_s; c < C_blk_e; c++) {
- size_t off = c + C_off;
- data_t gamma = use_scaleshift ? scaleshift[off] : 1;
- data_t sqrt_variance
- = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps));
- data_t v_mean = mean[off];
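-                // (Descriptive note, not in the original source.) The loop below applies the
-                // usual batch-norm backward formula: with diff stats enabled,
-                //   dx = gamma / sqrt(var + eps)
-                //        * (dy - mean(dy) - (x - mean) * mean(dy * (x - mean)) / (var + eps)),
-                // where dy is diff_dst masked by the ReLU workspace when fuse_bn_relu is set;
-                // with use_global_stats only the gamma / sqrt(var + eps) scaling of dy remains.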
- for (dim_t n = N_s; n < N_e; ++n)
-#if SAFE_TO_USE_OMP_SIMD
- PRAGMA_OMP_SIMD()
-#endif
- for (dim_t sp = S_s; sp < S_e; ++sp) {
- const size_t d_off = off * SP + n * C * SP + sp;
-
- data_t v_diff_src;
- if (fuse_bn_relu)
- v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
- else
- v_diff_src = diff_dst[d_off];
- if (calculate_diff_stats) {
- v_diff_src -= diff_beta_blk[c] / (SP * N)
- + (src[d_off] - v_mean) * diff_gamma_blk[c]
- * sqrt_variance / (SP * N);
- }
- v_diff_src *= gamma * sqrt_variance;
- diff_src[d_off] = v_diff_src;
- }
- }
- }
- });
-}
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp
deleted file mode 100644
index 97ca3b003f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp
+++ /dev/null
@@ -1,160 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP
-#define CPU_NCSP_BATCH_NORMALIZATION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_batch_normalization_fwd_pd_t {
- using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t);
-
- status_t init() {
- using namespace data_type;
- using namespace prop_kind;
- using namespace format_tag;
-
- bool ok = true
- && is_fwd()
- && !has_zero_dim_memory()
- && src_md()->data_type == f32
- && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
- && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc)
- && (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok) return status::unimplemented;
-
- if (is_training() && fuse_bn_relu()) init_default_ws(8);
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- if (!stats_is_src()) {
- scratchpad.book(key_bnorm_reduction,
- sizeof(data_t) * C() * mkldnn_get_max_threads());
-
- if (!is_training()) {
- scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C());
- scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C());
- }
- }
- }
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- ~ncsp_batch_normalization_fwd_t() {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_batch_normalization_bwd_pd_t {
- using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t);
-
- status_t init() {
- using namespace data_type;
- using namespace format_tag;
-
- bool ok = true
- && is_bwd()
- && !has_zero_dim_memory()
- && utils::everyone_is(f32, src_md()->data_type,
- diff_src_md()->data_type)
- && IMPLICATION(use_scaleshift(),
- utils::everyone_is(f32,
- weights_md()->data_type,
- diff_weights_md()->data_type))
- && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc)
- && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- if (fuse_bn_relu()) {
- init_default_ws(8);
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(key_bnorm_reduction,
- sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
- if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward))
- scratchpad.book(key_bnorm_tmp_diff_ss,
- sizeof(data_t) * 2 * C());
- }
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- ~ncsp_batch_normalization_bwd_t() {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp
deleted file mode 100644
index 38cfb28dce..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp
+++ /dev/null
@@ -1,392 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-
-#include "nhwc_pooling.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-#define MEM_D(name) name##_d
-
-#define DECLARE_READ_STRIDES(name) \
- const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \
- const size_t name##_d_stride = (!is_3d) \
- ? 0 \
- : MEM_D(name).blocking_desc().strides[2]; \
- const size_t name##_h_stride = (!is_3d) \
- ? MEM_D(name).blocking_desc().strides[2] \
- : MEM_D(name).blocking_desc().strides[3]; \
- const size_t name##_w_stride = (!is_3d) \
- ? MEM_D(name).blocking_desc().strides[3] \
- : MEM_D(name).blocking_desc().strides[4];
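-
-// (Descriptive note, not in the original source.) In the common non-3D case,
-// DECLARE_READ_STRIDES(src) yields src_n_stride = strides[0], src_d_stride = 0,
-// src_h_stride = strides[2] and src_w_stride = strides[3]; for the nhwc layouts
-// this implementation targets, the channel stride (strides[1]) is 1, so the
-// channels of one spatial position are contiguous and handled by the inner loops.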
-
-namespace nhwc_pooling {
- size_t strided_offset(const int _n, const size_t _sn,
- const int _d, const size_t _sd,
- const int _h, const size_t _sh,
- const int _w, const size_t _sw)
- {
- return _n * _sn
- + _d * _sd
- + _h * _sh
- + _w * _sw;
- }
-}
-
-template <impl::data_type_t data_type>
-void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n,
- const data_t *src, const size_t num, data_t *dst) const
-{
- for (int i = 0; i < n; ++i)
- {
- float ftmp = (float)src[i];
- ftmp = ftmp / num;
- dst[i] = math::out_round<data_t>(ftmp);
- }
-}
-
-template <impl::data_type_t data_type>
-void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src,
- data_t *dst) const
-{
- for (int i = 0; i < n; ++i)
- {
- dst[i] += src[i];
- }
-}
-
-template <impl::data_type_t data_type>
-void nhwc_pooling_fwd_t<data_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
- using namespace prop_kind;
- using namespace nhwc_pooling;
-
- auto alg = pd()->desc()->alg_kind;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
-
- const memory_desc_wrapper MEM_D(src)(pd()->src_md());
- const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
- const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
-
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
- const int MB = pd()->MB();
- const int OC = pd()->C();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
-
- const bool is_3d = pd()->desc()->src_desc.ndims == 5;
- const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
-
- DECLARE_READ_STRIDES(src);
- DECLARE_READ_STRIDES(dst);
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- parallel_nd(MB, OD, OH, OW,
- [&](int mb, int od, int oh, int ow) {
- size_t dst_offset_init = strided_offset(mb, dst_n_stride,
- od, dst_d_stride,
- oh, dst_h_stride,
- ow, dst_w_stride);
- if (alg == pooling_max) {
- size_t ws_offset_init = 0;
- if (ws)
- {
- DECLARE_READ_STRIDES(ws);
- ws_offset_init = strided_offset(mb, ws_n_stride,
- od, ws_d_stride,
- oh, ws_h_stride,
- ow, ws_w_stride);
- }
- // Note: GCC 4.8.5 won't vectorize below
- // simple loops unless they are singled out
- // into separate helper routines:
- // array_nhwc_initialize, array_nhwc_max
- if (!ws)
- array_nhwc_initialize<false>(OC, dst + dst_offset_init,
- ws, ws_offset_init, ws_dt);
- else
- array_nhwc_initialize<true>(OC, dst + dst_offset_init,
- ws, ws_offset_init, ws_dt);
-
-
- for (int kd = 0; kd < KD; ++kd)
- for (int kh = 0; kh < KH; ++kh)
- for (int kw = 0; kw < KW; ++kw) {
- const int id = od * SD - padF + kd;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
- if (id < 0 || id >= ID)
- continue;
- if (ih < 0 || ih >= IH)
- continue;
- if (iw < 0 || iw >= IW)
- continue;
-
- size_t src_offset_init = strided_offset(mb, src_n_stride,
- id, src_d_stride,
- ih, src_h_stride,
- iw, src_w_stride);
-
- if (!ws)
- array_nhwc_max<false>(OC,
- dst + dst_offset_init,
- src + src_offset_init,
- ws, ws_offset_init,
- ws_dt,
- kd * KH * KW + kh * KW + kw
- );
- else
- array_nhwc_max<true>(OC,
- dst + dst_offset_init,
- src + src_offset_init,
- ws, ws_offset_init,
- ws_dt,
- kd * KH * KW + kh * KW + kw
- );
- }
- } else {
- // pooling_avg
- auto d = dst + dst_offset_init;
-
- utils::array_set(d, 0, OC);
-
- auto id_start = apply_offset(od * SD, padF);
- auto ih_start = apply_offset(oh * SH, padT);
- auto iw_start = apply_offset(ow * SW, padL);
- auto id_end = nstl::min(od * SD - padF + KD, ID);
- auto ih_end = nstl::min(oh * SH - padT + KH, IH);
- auto iw_end = nstl::min(ow * SW - padL + KW, IW);
-
- // it is cheaper to actually count this in a loop
- // as the typical kernel is small
- size_t num_summands = 0;
-
- for (int id = id_start; id < id_end; ++id)
- for (int ih = ih_start; ih < ih_end; ++ih)
- for (int iw = iw_start; iw < iw_end; ++iw) {
- size_t src_offset_init = strided_offset(mb, src_n_stride,
- id, src_d_stride,
- ih, src_h_stride,
- iw, src_w_stride);
- auto s = src + src_offset_init;
-
- // need to move the loop to separate function
- // for GCC 4.8.5 to vectorize
- array_add(OC, s, d);
-
- num_summands++;
- }
-
- num_summands = (alg == pooling_avg_include_padding) ?
- KW * KH * KD : num_summands;
-
- // need to move the loop to separate function
- // for GCC 4.8.5 to vectorize
- array_div_by_const(OC, d, num_summands, d);
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void nhwc_pooling_bwd_t<data_type>::execute_backward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
- using namespace nhwc_pooling;
-
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
- const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
- const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
-
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int OC = pd()->C();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
-
- const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
- auto alg = pd()->desc()->alg_kind;
-
- DECLARE_READ_STRIDES(diff_src);
- DECLARE_READ_STRIDES(diff_dst);
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- const int MB = pd()->MB();
-
- parallel_nd(MB, ID, IH, IW,
- [&](int mb, int id, int ih, int iw) {
- size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
- id, diff_src_d_stride,
- ih, diff_src_h_stride,
- iw, diff_src_w_stride);
-
-        // Check whether the kernel windows are disjoint (kernel size equals stride).
-        // If they are, each input cell is written exactly once below, so no
-        // zero-initialization is needed; otherwise clear it first, because
-        // overlapping windows accumulate into it.
- if (!(KD == SD && KH == SH && KW == SW))
- for (int oc = 0; oc < OC; ++oc)
- diff_src[src_offset_init + oc] = data_type_t(0);
-
- // Find out which output cells may correspond to current
-        // input position. The current input position divided by
-        // the stride, with integer division rounding down, is the
-        // right-most output.
- // Left-most output may be computed if we decrement input
- // by (kernel_size - 1) and then do the same division by
- // stride.
- int od_left = nstl::max((id + padF - KD + 1) / SD, 0);
- int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0);
- int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0);
-        // Note the +1 here: it lets the loops below keep a strict
-        // "less than" continue condition.
- int od_right = nstl::min((id + padF) / SD + 1 , OD);
- int oh_right = nstl::min((ih + padT) / SH + 1 , OH);
- int ow_right = nstl::min((iw + padL) / SW + 1 , OW);
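-        // (Worked example, not in the original source.) With SD = 2, KD = 3, padF = 0
-        // and id = 5: od_left = (5 - 3 + 1) / 2 = 1 and od_right = 5 / 2 + 1 = 3
-        // (assuming OD >= 3), so od = 1 and od = 2 are candidates; the kd bounds
-        // check below keeps only od = 2 (kd = 1), since od = 1 gives kd = 3 >= KD.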
-
- for (int od = od_left; od < od_right; ++od)
- for (int oh = oh_left; oh < oh_right; ++oh)
- for (int ow = ow_left; ow < ow_right; ++ow) {
- const int kd = id - od*SD + padF;
- const int kh = ih - oh*SH + padT;
- const int kw = iw - ow*SW + padL;
-
- if (kd < 0 || kd >= KD)
- continue;
- if (kh < 0 || kh >= KH)
- continue;
- if (kw < 0 || kw >= KW)
- continue;
-
- size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
- od, diff_dst_d_stride,
- oh, diff_dst_h_stride,
- ow, diff_dst_w_stride);
-
- if (alg == pooling_max) {
- DECLARE_READ_STRIDES(ws);
- size_t ws_offset_init = strided_offset(mb, ws_n_stride,
- od, ws_d_stride,
- oh, ws_h_stride,
- ow, ws_w_stride);
- const int index = kd * KH * KW + kh * KW + kw;
-
- PRAGMA_OMP_SIMD()
- for (int oc = 0; oc < OC; ++oc) {
- const int index_from_ws =
- (MEM_D(ws).data_type() == data_type::u8)
- ? (int)ws[ws_offset_init + oc]
- : ((int *)ws)[ws_offset_init + oc];
-
- const data_t d = diff_dst[dst_offset_init + oc];
-
-                    // If the kernel windows are disjoint, each cell is written
-                    // exactly once, so plain assignment suffices; otherwise the
-                    // value is accumulated into the existing contents.
- if (!(KD == SD && KH == SH && KW == SW))
- diff_src[src_offset_init + oc] +=
- (index_from_ws == index)
- ? d
- : data_type_t(0);
- else
- diff_src[src_offset_init + oc] =
- (index_from_ws == index)
- ? d
- : data_type_t(0);
- }
- } else {
- // pooling_avg
- auto id_start = apply_offset(od*SD, padF);
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto id_end = nstl::min(od*SD - padF + KD, ID);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding)
- ? KW*KH*KD
- : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
-
- PRAGMA_OMP_SIMD()
- for (int oc = 0; oc < OC; ++oc) {
- const data_t d = diff_dst[dst_offset_init + oc];
-                    // If the kernel windows are disjoint, each cell is written
-                    // exactly once, so plain assignment suffices; otherwise the
-                    // value is accumulated into the existing contents.
- if (!(KD == SD && KH == SH && KW == SW))
- diff_src[src_offset_init + oc] += d / num_summands;
- else
- diff_src[src_offset_init + oc] = d / num_summands;
- }
- }
- }
- });
-}
-
-template struct nhwc_pooling_fwd_t<data_type::f32>;
-template struct nhwc_pooling_bwd_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
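The output-window arithmetic in the backward kernel above is easy to get wrong, so here is a minimal standalone restatement of it in plain C++ (hypothetical names, one spatial dimension only; the deleted kernel applies the same formula independently to D, H and W and then filters candidates with the kd/kh/kw bounds checks):

    #include <algorithm>
    #include <cstdio>

    // For input position i (with padding pad, kernel size K, stride S, output
    // extent O), return the half-open candidate range [left, right) of output
    // positions whose pooling window may cover i.
    struct OutRange { int left, right; };

    OutRange covering_outputs(int i, int pad, int K, int S, int O) {
        int left  = std::max((i + pad - K + 1) / S, 0); // left-most candidate
        int right = std::min((i + pad) / S + 1, O);     // +1 keeps the "< right" loop form
        return {left, right};
    }

    int main() {
        // Example: IH = 6, KH = 3, SH = 2, padT = 1, padB = 0  ->  OH = 3.
        for (int ih = 0; ih < 6; ++ih) {
            OutRange r = covering_outputs(ih, 1, 3, 2, 3);
            std::printf("ih=%d -> oh in [%d, %d)\n", ih, r.left, r.right);
        }
        return 0;
    }

Because integer division truncates, the candidate range can include one extra output on the left; that is why the kernel re-checks that kh = ih - oh*SH + padT lies in [0, KH) before using a candidate.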
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp
deleted file mode 100644
index 7e33b6869f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp
+++ /dev/null
@@ -1,210 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_NHWC_POOLING_HPP
-#define CPU_NHWC_POOLING_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_pooling_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace nhwc_pooling {
-size_t strided_offset(const int _n, const size_t _sn, const int _d,
- const size_t _sd, const int _h, const size_t _sh, const int _w,
- const size_t _sw);
-}
-
-template <impl::data_type_t data_type>
-struct nhwc_pooling_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_fwd_pd_t {
- using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t);
-
- status_t init() {
- const format_tag_t desired_fmt_tag =
- ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
- alg_kind::pooling_avg_include_padding,
- alg_kind::pooling_avg_exclude_padding)
- && utils::everyone_is(data_type,
- src_md()->data_type,
- dst_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
- && memory_desc_matches_tag(*dst_md(), desired_fmt_tag);
- if (!ok) return status::unimplemented;
-
- bool is_training = desc_.prop_kind == prop_kind::forward_training;
- if (desc()->alg_kind == alg_kind::pooling_max && is_training)
- init_default_ws();
-
- return status::success;
- }
- };
-
- nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- void array_div_by_const(const int n, const data_t *src, const size_t num,
- data_t *dst) const;
- void array_add(const int n, const data_t *src, data_t *dst) const;
-
- template <bool use_workspace>
- void array_nhwc_max(const int n, data_t *dst, const data_t *src,
- unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt,
- const int index) const {
-        assert(!((use_workspace == false) ^ (!ws))); // ws must be non-null iff use_workspace is true
- PRAGMA_OMP_SIMD()
- for (int oc = 0; oc < n; ++oc) {
- auto s = src[oc];
- data_t mv = dst[oc];
-
- // update index of maximum
-#if defined __INTEL_COMPILER
- if ((use_workspace) && (s > mv)) {
- assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
- if (ws_dt == data_type::u8) {
- assert(0 <= index && index <= 255);
- ws[ws_offset + oc] = index;
- } else
- reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
- }
-#else
-            // Need to add explicit predicates for GCC to vectorize this.
-            // Although the resulting code is ugly, it is still 4 times
-            // faster than the scalar version.
- if (use_workspace) {
- assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
-
- if (ws_dt == data_type::u8) {
- assert(0 <= index && index <= 255);
- unsigned char predicate = (s > mv) ? 0xff : 0;
- unsigned char current_value = ws[ws_offset + oc];
- current_value = (predicate & (unsigned char)index)
- | ((~predicate) & current_value);
- ws[ws_offset + oc] = current_value;
- } else {
- auto wint = reinterpret_cast<int *>(ws);
- unsigned int predicate = (s > mv) ? 0xffffffff : 0;
- unsigned int current_value = wint[ws_offset + oc];
- current_value = (predicate & (unsigned int)index)
- | ((~predicate) & current_value);
- wint[ws_offset + oc] = current_value;
- }
- }
-#endif
- // update maximum
- dst[oc] = nstl::max(s, mv);
- }
- }
-
- template <bool use_workspace>
- void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws,
- const size_t ws_offset, const data_type_t ws_dt) const {
-        assert(!((use_workspace == false) ^ (!ws))); // ws must be non-null iff use_workspace is true
- for (int oc = 0; oc < n; ++oc) {
- if (use_workspace) {
- assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
- if (ws_dt == data_type::u8) {
- ws[ws_offset + oc] = 0;
- } else
- reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
- }
- dst[oc] = nstl::numeric_limits<data_t>::lowest();
- }
- }
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct nhwc_pooling_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_bwd_pd_t {
- using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t);
-
- status_t init() {
- const format_tag_t desired_fmt_tag =
- ndims() == 4 ? format_tag::nchw : format_tag::ncdhw;
-
- bool ok = true
- && set_default_params() == status::success
- && !is_fwd()
- && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
- alg_kind::pooling_avg_include_padding,
- alg_kind::pooling_avg_exclude_padding)
- && utils::everyone_is(data_type,
- diff_dst_md()->data_type,
- diff_src_md()->data_type)
- && attr()->has_default_values()
- && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
- && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag);
- if (!ok) return status::unimplemented;
-
- if (desc()->alg_kind == alg_kind::pooling_max) {
- init_default_ws();
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- return status::success;
- }
- };
-
- nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}// namespace cpu
-}// namespace impl
-}// namespace mkldnn
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
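The #else branch of array_nhwc_max above relies on a branchless select-via-mask idiom so that GCC can vectorize the argmax update. A tiny standalone illustration of the same idiom, with made-up values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        float s = 3.5f, mv = 2.0f;        // candidate value and current maximum
        uint32_t index = 7, current = 4;  // candidate index and stored argmax

        // All-ones mask when the candidate wins, all-zeros otherwise; the blend
        // below then needs no data-dependent branch inside the inner loop.
        uint32_t predicate = (s > mv) ? 0xffffffffu : 0u;
        current = (predicate & index) | (~predicate & current);

        std::printf("selected index = %u\n", current); // prints 7
        return 0;
    }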
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
deleted file mode 100644
index e20333e66f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
+++ /dev/null
@@ -1,288 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-
-#include "cpu_batch_normalization_utils.hpp"
-#include "jit_generator.hpp"
-
-#include "nspc_batch_normalization.hpp"
-
-// clang 6 and 7 generate incorrect code with OMP_SIMD in some cases, so it is disabled for clang >= 6
-#if (defined __clang_major__) && (__clang_major__ >= 6)
-#define SAFE_TO_USE_OMP_SIMD 0
-#else
-#define SAFE_TO_USE_OMP_SIMD 1
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace memory_tracking::names;
-
-void nspc_batch_normalization_fwd_t::execute_forward(
- const exec_ctx_t &ctx) const {
- const bool save_stats = pd()->is_training();
- const bool is_training = pd()->is_training();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
- const bool calculate_stats = !pd()->stats_is_src();
- const bool with_relu = pd()->with_relu_post_op();
-
- auto scratchpad = this->scratchpad(ctx);
- auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
- auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var);
- auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
-
- data_t *mean, *variance;
- if (!calculate_stats) {
- mean = const_cast<data_t *>(
- CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
- variance = const_cast<data_t *>(
- CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
- } else {
- if (save_stats) {
- mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
- variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
- } else {
- mean = tmp_mean;
- variance = tmp_var;
- }
- }
-
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- const dim_t N = pd()->MB();
- const dim_t C = pd()->C();
- const dim_t SP = pd()->H() * pd()->W() * pd()->D();
-
- const float eps = pd()->desc()->batch_norm_epsilon;
- const bool use_scaleshift = pd()->use_scaleshift();
- auto maybe_post_op
- = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
-
- assert(mkldnn_thr_syncable());
- parallel(0, [&](const int ithr, const int nthr) {
- dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
- balance211(N, nthr, ithr, N_s, N_e);
- balance211(C, nthr, ithr, C_s, C_e);
- data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
- data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
-
- if (calculate_stats) {
- for (dim_t c = 0; c < C; c++)
- ws_reduce[C * ithr + c] = 0.;
-
- for (dim_t n = N_s; n < N_e; n++)
- for (dim_t sp = 0; sp < SP; sp++)
- PRAGMA_OMP_SIMD()
- for (dim_t c = 0; c < C; c++)
- ws_reduce[C * ithr + c] += src[(size_t)n * SP * C
- + sp * C + c];
-
- mkldnn_thr_barrier();
-
- for (dim_t c = C_s; c < C_e; c++) {
- mean[c] = 0;
- for (dim_t n = 0; n < nthr; n++)
- mean[c] += ws_reduce[C * n + c];
- mean[c] /= SP * N;
- }
-
- mkldnn_thr_barrier();
-
- for (dim_t c = 0; c < C; c++) {
- mean_loc[c] = mean[c];
- ws_reduce[C * ithr + c] = 0.;
- }
-
- for (dim_t n = N_s; n < N_e; n++)
- for (dim_t sp = 0; sp < SP; sp++)
- PRAGMA_OMP_SIMD()
- for (dim_t c = 0; c < C; c++) {
- data_t m = src[(size_t)n * SP * C + sp * C + c]
- - mean_loc[c];
- ws_reduce[C * ithr + c] += m * m;
- }
-
- mkldnn_thr_barrier();
-
- for (dim_t c = C_s; c < C_e; c++) {
- variance[c] = 0;
- for (dim_t n = 0; n < nthr; n++)
- variance[c] += ws_reduce[C * n + c];
- variance[c] /= SP * N;
- }
-
- mkldnn_thr_barrier();
-
- for (dim_t c = 0; c < C; c++)
- variance_loc[c] = variance[c];
- } else {
- variance_loc = variance;
- mean_loc = mean;
- }
-
- for (dim_t n = N_s; n < N_e; n++) {
- for (dim_t sp = 0; sp < SP; sp++) {
-#if SAFE_TO_USE_OMP_SIMD
- PRAGMA_OMP_SIMD()
-#endif
- for (dim_t c = 0; c < C; c++) {
- data_t sqrt_variance = static_cast<data_t>(
- sqrtf(variance_loc[c] + eps));
- data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance;
- data_t sv = use_scaleshift ? scaleshift[C + c] : 0;
- size_t d_off = (size_t)n * SP * C + sp * C + c;
- data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv;
- if (fuse_bn_relu) {
- if (bn_res <= 0) {
- bn_res = 0;
- if (is_training)
- ws[d_off] = 0;
- } else {
- if (is_training)
- ws[d_off] = 1;
- }
- }
- dst[d_off] = maybe_post_op(bn_res);
- }
- }
- }
- });
-}
-
-void nspc_batch_normalization_bwd_t::execute_backward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
- auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
- auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
- auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
-
- auto scratchpad = this->scratchpad(ctx);
- auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
-
- if (diff_scaleshift == nullptr)
- diff_scaleshift = tmp_diff_ss;
-
- const dim_t N = pd()->MB();
- const dim_t C = pd()->C();
- const dim_t SP = pd()->D() * pd()->H() * pd()->W();
- data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C;
- auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
-
- const float eps = pd()->desc()->batch_norm_epsilon;
- const bool use_scaleshift = pd()->use_scaleshift();
- const bool calculate_diff_stats = !pd()->use_global_stats();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
-
- assert(mkldnn_thr_syncable());
- parallel(0, [&](const int ithr, const int nthr) {
- dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
- balance211(N, nthr, ithr, N_s, N_e);
- balance211(C, nthr, ithr, C_s, C_e);
-
- data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
- data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);
-
- for (dim_t c = 0; c < C; c++) {
- ws_reduce[C * ithr + c] = 0.;
- ws_reduce[C * nthr + C * ithr + c] = 0.;
- }
-
- for (dim_t n = N_s; n < N_e; n++)
- for (dim_t sp = 0; sp < SP; sp++)
-#if SAFE_TO_USE_OMP_SIMD
- PRAGMA_OMP_SIMD()
-#endif
- for (dim_t c = 0; c < C; c++) {
- const size_t d_off = (size_t)n * SP * C + sp * C + c;
- data_t dd;
- if (fuse_bn_relu)
- dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
- else
- dd = diff_dst[d_off];
- ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd;
- ws_reduce[C * nthr + C * ithr + c] += dd;
- }
-
- mkldnn_thr_barrier();
-
- for (dim_t c = C_s; c < C_e; c++) {
- data_t sqrt_variance
- = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
- diff_gamma[c] = 0;
- diff_beta[c] = 0;
- for (dim_t n = 0; n < nthr; n++) {
- diff_gamma[c] += ws_reduce[C * n + c];
- diff_beta[c] += ws_reduce[C * nthr + C * n + c];
- }
- diff_gamma[c] *= sqrt_variance;
- }
-
- mkldnn_thr_barrier();
-
- for (dim_t c = 0; c < C; c++) {
- diff_gamma_loc[c] = diff_gamma[c];
- diff_beta_loc[c] = diff_beta[c];
- }
-
- for (dim_t n = N_s; n < N_e; n++) {
- for (dim_t sp = 0; sp < SP; sp++) {
-#if SAFE_TO_USE_OMP_SIMD
- PRAGMA_OMP_SIMD()
-#endif
- for (dim_t c = 0; c < C; c++) {
- const size_t d_off = (size_t)n * SP * C + sp * C + c;
- data_t gamma = use_scaleshift ? scaleshift[c] : 1;
- data_t sqrt_variance
- = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
- data_t v_diff_src;
- if (fuse_bn_relu)
- v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
- else
- v_diff_src = diff_dst[d_off];
- if (calculate_diff_stats) {
- v_diff_src -= diff_beta_loc[c] / (SP * N)
- + (src[d_off] - mean[c]) * diff_gamma_loc[c]
- * sqrt_variance / (SP * N);
- }
- v_diff_src *= gamma * sqrt_variance;
- diff_src[d_off] = v_diff_src;
- }
- }
- }
- });
-}
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
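Stripped of the threading, scratchpad handling and fused ReLU, the forward pass above reduces to a per-channel mean/variance reduction followed by normalization. A single-threaded sketch over an NHWC buffer (hypothetical helper, no scale/shift applied):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // src and dst hold N * SP * C floats in NHWC order (channels innermost).
    void bnorm_fwd_nhwc(const std::vector<float> &src, std::vector<float> &dst,
                        int N, int SP, int C, float eps) {
        std::vector<float> mean(C, 0.f), var(C, 0.f);
        for (int n = 0; n < N; ++n)                    // per-channel mean
            for (int sp = 0; sp < SP; ++sp)
                for (int c = 0; c < C; ++c)
                    mean[c] += src[((size_t)n * SP + sp) * C + c];
        for (int c = 0; c < C; ++c) mean[c] /= (float)N * SP;

        for (int n = 0; n < N; ++n)                    // per-channel variance
            for (int sp = 0; sp < SP; ++sp)
                for (int c = 0; c < C; ++c) {
                    float m = src[((size_t)n * SP + sp) * C + c] - mean[c];
                    var[c] += m * m;
                }
        for (int c = 0; c < C; ++c) var[c] /= (float)N * SP;

        for (int n = 0; n < N; ++n)                    // normalize
            for (int sp = 0; sp < SP; ++sp)
                for (int c = 0; c < C; ++c) {
                    size_t off = ((size_t)n * SP + sp) * C + c;
                    dst[off] = (src[off] - mean[c]) / std::sqrt(var[c] + eps);
                }
    }

The deleted implementation computes exactly these reductions, but splits the batch range across threads, accumulates partial sums into ws_reduce, and separates the mean, variance and normalization phases with mkldnn_thr_barrier().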
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp
deleted file mode 100644
index aad86b05a7..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp
+++ /dev/null
@@ -1,169 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP
-#define CPU_NSPC_BATCH_NORMALIZATION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct nspc_batch_normalization_fwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_batch_normalization_fwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t);
-
- status_t init() {
- using namespace data_type;
- using namespace prop_kind;
-
- bool ok = true
- /* the algorithm requires barriers while switching
- * between parallelization over N and C dimensions */
- && mkldnn_thr_syncable()
- && is_fwd()
- && !has_zero_dim_memory()
- && src_md()->data_type == f32
- && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
- && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
- && (attr()->has_default_values() || this->with_relu_post_op());
- if (!ok) return status::unimplemented;
-
- if (is_training() && fuse_bn_relu()) init_default_ws(8);
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- if (!stats_is_src()) {
- dim_t sz = nstl::max<dim_t>(C(), 16) * mkldnn_get_max_threads();
- scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz);
- scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz);
- scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz);
- }
- }
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- ~nspc_batch_normalization_fwd_t() {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
- struct pd_t : public cpu_batch_normalization_bwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t);
-
- status_t init() {
- using namespace data_type;
- using namespace prop_kind;
-
- bool ok = true
- /* the algorithm requires barriers while switching
- * between parallelization over N and C dimensions */
- && mkldnn_thr_syncable()
- && is_bwd()
- && !has_zero_dim_memory()
- && utils::everyone_is(f32, src_md()->data_type,
- diff_src_md()->data_type)
- && IMPLICATION(use_scaleshift(),
- utils::everyone_is(f32,
- weights_md()->data_type,
- diff_weights_md()->data_type))
- && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
- && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- if (fuse_bn_relu()) {
- init_default_ws(8);
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(key_bnorm_reduction,
- sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
- scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C()
- * (mkldnn_get_max_threads() + 1));
- }
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- ~nspc_batch_normalization_bwd_t() {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
deleted file mode 100644
index d79b1a034b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
+++ /dev/null
@@ -1,265 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "simple_q10n.hpp"
-
-#include "ref_batch_normalization.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-void ref_batch_normalization_fwd_t<data_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- /* fast return */
- if (this->pd()->has_zero_dim_memory()) return;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT);
-
- auto mean = pd()->stats_is_src()
- ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN))
- : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN);
- auto variance = pd()->stats_is_src()
- ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE))
- : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE);
-
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const memory_desc_wrapper scaleshift_d(pd()->weights_md());
-
- const dim_t N = pd()->MB();
- const dim_t C = pd()->C();
- dim_t H = 1, W = 1, D = 1;
- const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
- if (has_spatial) {
- D = pd()->D();
- H = pd()->H();
- W = pd()->W();
- }
-
- const float eps = pd()->desc()->batch_norm_epsilon;
-    const bool use_scaleshift = pd()->use_scaleshift();
- const bool save_stats = pd()->is_training();
- const bool is_training = pd()->is_training();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
- const bool calculate_stats = !pd()->stats_is_src();
-
- const bool with_relu = pd()->with_relu_post_op();
- auto maybe_post_op = [&](float res) {
- return (with_relu && res < 0.0f) ? 0.0f : res;
- };
- const bool is_3d = data_d.ndims() == 5;
-
- auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
- dim_t d, dim_t h, dim_t w) {
- if (has_spatial) {
- if (is_3d)
- return data_d.off(n, c, d, h, w);
- else
- return data_d.off(n, c, h, w);
- } else
- return data_d.off(n, c);
- };
-
- parallel_nd(C, [&](dim_t c) {
- float v_mean = calculate_stats ? 0 : mean[c];
- float v_variance = calculate_stats ? 0 : variance[c];
-
- if (calculate_stats) {
- for (dim_t n = 0; n < N; ++n)
- for (dim_t d = 0; d < D; ++d)
- for (dim_t h = 0; h < H; ++h)
- for (dim_t w = 0; w < W; ++w)
- v_mean += src[data_offset(data_d, n, c, d, h, w)];
- v_mean /= W*N*H*D;
-
- for (dim_t n = 0; n < N; ++n)
- for (dim_t d = 0; d < D; ++d)
- for (dim_t h = 0; h < H; ++h)
- for (dim_t w = 0; w < W; ++w) {
- float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean;
- v_variance += m*m;
- }
- v_variance /= W*H*N*D;
- }
-
- float sqrt_variance = sqrtf(v_variance + eps);
- float sm = (use_scaleshift
- ? scaleshift[scaleshift_d.off(0, c)]
- : 1.0f) / sqrt_variance;
- float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
-
- for (dim_t n = 0; n < N; ++n)
- for (dim_t d = 0; d < D; ++d)
- for (dim_t h = 0; h < H; ++h)
- for (dim_t w = 0; w < W; ++w) {
- auto d_off = data_offset(data_d,n,c,d,h,w);
- float bn_res = sm * ((float)src[d_off] - v_mean) + sv;
- if (fuse_bn_relu) {
- if (bn_res <= 0) {
- bn_res = 0;
- if (is_training)
- ws[d_off] = 0;
- } else {
- if (is_training)
- ws[d_off] = 1;
- }
- }
- if (data_type == data_type::s8) {
- dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
- } else {
- dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res));
- }
- }
-
- if (calculate_stats) {
- if (save_stats) {
- mean[c] = v_mean;
- variance[c] = v_variance;
- }
- }
- });
-}
-
-template struct ref_batch_normalization_fwd_t<data_type::f32>;
-template struct ref_batch_normalization_fwd_t<data_type::s8>;
-
-template <impl::data_type_t data_type>
-void ref_batch_normalization_bwd_t<data_type>::execute_backward(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
- auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
- auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
-
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
- auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
- const memory_desc_wrapper scaleshift_d(pd()->weights_md());
- const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md());
-
- const dim_t C = pd()->C();
-
- /* fast return */
- if (this->pd()->has_zero_dim_memory()) {
- if (diff_scaleshift) {
- for (dim_t c = 0; c < C; ++c) {
- diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
- diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
- }
- }
- return;
- }
-
- const dim_t N = pd()->MB();
- dim_t H = 1, W = 1, D = 1;
- const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
- if (has_spatial) {
- D = pd()->D();
- H = pd()->H();
- W = pd()->W();
- }
-
- const float eps = pd()->desc()->batch_norm_epsilon;
- const bool use_scaleshift = pd()->use_scaleshift();
- const bool calculate_diff_stats = !pd()->use_global_stats();
- const bool fuse_bn_relu = pd()->fuse_bn_relu();
-
- const bool is_3d = data_d.ndims() == 5;
-
- auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
- dim_t d, dim_t h, dim_t w) {
- if (has_spatial) {
- if (is_3d)
- return data_d.off(n, c, d, h, w);
- else
- return data_d.off(n, c, h, w);
- } else
- return data_d.off(n, c);
- };
-
- parallel_nd(C, [&](dim_t c) {
- data_t v_mean = mean[c];
- data_t v_variance = variance[c];
- data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
- data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
-        data_t diff_gamma = data_t(0);
-        data_t diff_beta = data_t(0);
-
- for (dim_t n = 0; n < N; ++n)
- for (dim_t d = 0; d < D; ++d)
- for (dim_t h = 0; h < H; ++h)
- for (dim_t w = 0; w < W; ++w) {
- const size_t s_off = data_offset(data_d, n, c, d, h, w);
- data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)];
- if (fuse_bn_relu && !ws[s_off])
- dd = 0;
-
- diff_gamma += (src[s_off] - v_mean) * dd;
- diff_beta += dd;
- }
- diff_gamma *= sqrt_variance;
-
- if (diff_scaleshift) {
- diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
- diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
- }
-
- for (dim_t n = 0; n < N; ++n)
- for (dim_t d = 0; d < D; ++d)
- for (dim_t h = 0; h < H; ++h)
- for (dim_t w = 0; w < W; ++w) {
- const size_t s_off = data_offset(data_d, n, c, d, h, w);
- const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
- data_t dd = diff_dst[dd_off];
- if (fuse_bn_relu && !ws[s_off])
- dd = 0;
-
- data_t v_diff_src = dd;
- if (calculate_diff_stats) {
- v_diff_src -= diff_beta/(D*W*H*N) +
- (src[s_off] - v_mean) *
- diff_gamma*sqrt_variance/(D*W*H*N);
- }
- v_diff_src *= gamma*sqrt_variance;
- diff_src[dd_off] = v_diff_src;
- }
- });
-}
-
-template struct ref_batch_normalization_bwd_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
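For reference, the per-channel quantities computed by the backward kernel above are the usual batch-normalization gradients. Writing mu_c and sigma_c^2 for the saved statistics and dy for diff_dst, the code evaluates (summations over n, d, h, w; symbols here are notation only, not identifiers from the source):

    d\gamma_c = \frac{1}{\sqrt{\sigma_c^2 + \varepsilon}} \sum_{n,d,h,w} (x - \mu_c)\, dy
    \qquad
    d\beta_c = \sum_{n,d,h,w} dy

    dx = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}}
         \left( dy - \frac{1}{N\,D\,H\,W} \left( d\beta_c
               + \frac{(x - \mu_c)\, d\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}} \right) \right)

with the inner correction term dropped when use_global_stats() is set (calculate_diff_stats == false).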
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp
deleted file mode 100644
index aa9f74125a..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp
+++ /dev/null
@@ -1,127 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_BATCH_NORMALIZATION_HPP
-#define CPU_REF_BATCH_NORMALIZATION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_batch_normalization_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-struct ref_batch_normalization_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_batch_normalization_fwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && src_md()->data_type == data_type
- && IMPLICATION(use_scaleshift(),
- weights_md()->data_type == data_type::f32)
- && (attr()->has_default_values() || with_relu_post_op());
- if (!ok) return status::unimplemented;
-
- if (src_md()->data_type == data_type::s8 && !stats_is_src())
- return status::unimplemented;
-
- if (is_training() && fuse_bn_relu()) init_default_ws(8);
-
- return status::success;
- }
- };
-
- ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct ref_batch_normalization_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_batch_normalization_bwd_pd_t {
- pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
- const primitive_attr_t *attr,
- const batch_normalization_fwd_pd_t *hint_fwd_pd)
- : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- {}
-
- DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t);
-
- status_t init() {
- bool ok = true
- && is_bwd()
- && utils::everyone_is(data_type, src_md()->data_type,
- diff_src_md()->data_type)
- && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type,
- weights_md()->data_type,
- diff_weights_md()->data_type))
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- if (fuse_bn_relu()) {
- init_default_ws(8);
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- return status::success;
- }
- };
-
- ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp
deleted file mode 100644
index 4c534b5508..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp
+++ /dev/null
@@ -1,97 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef REF_CONCAT_HPP
-#define REF_CONCAT_HPP
-
-#include "reorder_pd.hpp"
-
-#include "cpu_concat_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct ref_concat_t: public cpu_primitive_t {
- struct pd_t: public cpu_concat_pd_t {
- using cpu_concat_pd_t::cpu_concat_pd_t;
-
- pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
- for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
- reorder_pds_.push_back(
- (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
- }
- ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
-
- DECLARE_CONCAT_PD_T("ref:any", ref_concat_t);
-
- status_t init() {
- bool ok = cpu_concat_pd_t::init() == status::success;
- if (!ok) return status::unimplemented;
-
- for (int i = 0; i < n_; ++i) {
- auto r_impls = engine_->get_reorder_implementation_list();
- for (auto r = r_impls; *r; ++r) {
- const primitive_attr_t attr; /* alpha == 1. */
- reorder_pd_t *r_pd = nullptr;
- if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
- engine_, src_image_md(i)) == status::success) {
- r_pd->init_info();
- reorder_pds_.push_back(r_pd);
- break;
- }
- }
- }
-
- ok = reorder_pds_.size() == (size_t)n_;
- return ok ? status::success : status::unimplemented;
- }
-
- nstl::vector<const reorder_pd_t *> reorder_pds_;
- };
-
- ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) {
- const int n = pd()->n_inputs();
- reorders_.resize(n);
- for (int i = 0; i < n; ++i)
- pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
- }
-
- ~ref_concat_t() { for (auto &r: reorders_) delete r; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto n = pd()->n_inputs();
- for (int i = 0; i < n; ++i) {
- exec_args_t r_args;
- r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
- r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
- exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
- reorders_[i]->execute(r_ctx);
- }
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- nstl::vector<primitive_t *> reorders_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp
deleted file mode 100644
index c0a979c4cf..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp
+++ /dev/null
@@ -1,395 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "mkldnn_traits.hpp"
-#include "type_helpers.hpp"
-
-#include "ref_convolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using math::saturate;
-using math::get_bias;
-
-template <data_type_t src_type, data_type_t wei_type,
- data_type_t dst_type, data_type_t acc_type>
-void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>::
-execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const bool with_groups = pd()->with_groups();
-
- const int G = pd()->G();
- const int MB = pd()->MB();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
-
- const int OC = pd()->OC() / G;
- const int IC = pd()->IC() / G;
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
-
- const int KSD = pd()->KSD();
- const int KSH = pd()->KSH();
- const int KSW = pd()->KSW();
-
- const int KDD = pd()->KDD();
- const int KDH = pd()->KDH();
- const int KDW = pd()->KDW();
-
- const int padFront = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
-    const bool with_relu = false; // TODO: change once post-ops are supported
- const float nslope = 0.f;
-
- const int ndims = pd()->desc()->src_desc.ndims;
-
- auto ker = [=](int g, int mb, int oc, int od, int oh,
- int ow) {
- acc_data_t d = 0;
- for (int ic = 0; ic < IC; ++ic)
- for (int kd = 0; kd < KD; ++kd)
- for (int kh = 0; kh < KH; ++kh)
- for (int kw = 0; kw < KW; ++kw) {
- const int id = od * KSD - padFront + kd * (1 + KDD);
- const int ih = oh * KSH - padT + kh * (1 + KDH);
- const int iw = ow * KSW - padL + kw * (1 + KDW);
-
- if (id < 0 || id >= ID) continue;
- if (ih < 0 || ih >= IH) continue;
- if (iw < 0 || iw >= IW) continue;
-
- if (ndims == 5)
- d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)]
- * (with_groups
- ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
- : weights[weights_d.off(oc, ic, kd, kh, kw)]);
- else if (ndims == 4)
- d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)]
- * (with_groups
- ? weights[weights_d.off(g, oc, ic, kh, kw)]
- : weights[weights_d.off(oc, ic, kh, kw)]);
- else if (ndims == 3)
- d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)]
- * (with_groups
- ? weights[weights_d.off(g, oc, ic, kw)]
- : weights[weights_d.off(oc, ic, kw)]);
- else
- assert(false);
-
- }
- return d;
- };
-
- parallel_nd(G, MB, OC, OD, OH, OW,
- [&](int g, int mb, int oc, int od, int oh, int ow) {
- float a = bias
- ? get_bias(bias, bias_d.off(g * OC + oc),
- pd()->desc()->bias_desc.data_type)
- : 0;
- a += ker(g, mb, oc, od, oh, ow);
- if (with_relu && a < 0)
- a = a * nslope;
- if (ndims == 5)
- dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate<dst_data_t>(a);
- else if (ndims == 4)
- dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate<dst_data_t>(a);
- else if (ndims == 3)
- dst[dst_d.off(mb, g*OC + oc, ow)] = saturate<dst_data_t>(a);
- else
- assert(false);
- });
-}
-
-template <data_type_t diff_src_type, data_type_t wei_type,
- data_type_t diff_dst_type, data_type_t acc_type>
-void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
- acc_type>::execute_backward_data(const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const bool with_groups = pd()->with_groups();
-
- const int G = pd()->G();
- const int MB = pd()->MB();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
-
- const int OC = pd()->OC() / G;
- const int IC = pd()->IC() / G;
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
-
- const int KSD = pd()->KSD();
- const int KSH = pd()->KSH();
- const int KSW = pd()->KSW();
-
- const int KDD = pd()->KDD();
- const int KDH = pd()->KDH();
- const int KDW = pd()->KDW();
-
- const int padFront = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- const int ndims = pd()->desc()->diff_src_desc.ndims;
-
- auto ker = [=](int g, int mb, int ic, int id, int ih,
- int iw) {
- acc_data_t d = 0;
- for (int oc = 0; oc < OC; ++oc)
- for (int kd = 0; kd < KD; ++kd)
- for (int kh = 0; kh < KH; ++kh)
- for (int kw = 0; kw < KW; ++kw) {
- if (iw + padL < kw * (1 + KDW)
- || ih + padT < kh * (1 + KDH)
- || id + padFront < kd * (1 + KDD))
- continue;
- int ow = iw - kw * (1 + KDW) + padL;
- int oh = ih - kh * (1 + KDH) + padT;
- int od = id - kd * (1 + KDD) + padFront;
- if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0)
- continue;
-
- ow /= KSW;
- oh /= KSH;
- od /= KSD;
-
- if (od < OD && oh < OH && ow < OW) {
- if (ndims == 5)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
- + oc, od, oh, ow)] * (with_groups
- ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
- : weights[weights_d.off(oc, ic, kd, kh, kw)]);
- else if (ndims == 4)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
- + oc, oh, ow)] * (with_groups
- ? weights[weights_d.off(g, oc, ic, kh, kw)]
- : weights[weights_d.off(oc, ic, kh, kw)]);
- else if (ndims == 3)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
- + oc, ow)] * (with_groups
- ? weights[weights_d.off(g, oc, ic, kw)]
- : weights[weights_d.off(oc, ic, kw)]);
- else
- assert(false);
- }
- }
- return d;
- };
-
- parallel_nd(G, MB, IC, ID, IH, IW,
- [&](int g, int mb, int ic, int id, int ih, int iw) {
- auto ds_idx = (ndims == 5)
- ? diff_src_d.off(mb, g*IC + ic, id, ih, iw)
- : (ndims == 4)
- ? diff_src_d.off(mb, g*IC + ic, ih, iw)
- : diff_src_d.off(mb, g*IC + ic, iw);
- float a = bias
- ? get_bias(bias, bias_d.off(g * IC + ic),
- pd()->desc()->bias_desc.data_type)
- : 0;
- a += ker(g, mb, ic, id, ih, iw);
- diff_src[ds_idx] = saturate<diff_src_data_t>(a);
- });
-}
-
-template <data_type_t src_type, data_type_t diff_wei_type,
- data_type_t diff_dst_type, data_type_t acc_type>
-void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
- acc_type>::execute_backward_weights(const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
- const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
-
- const bool with_groups = pd()->with_groups();
-
- const int G = pd()->G();
- const int MB = pd()->MB();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
-
- const int OC = pd()->OC() / G;
- const int IC = pd()->IC() / G;
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
-
- const int KSD = pd()->KSD();
- const int KSH = pd()->KSH();
- const int KSW = pd()->KSW();
-
- const int KDD = pd()->KDD();
- const int KDH = pd()->KDH();
- const int KDW = pd()->KDW();
-
- const int padFront = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- const int ndims = pd()->desc()->src_desc.ndims;
-
-    auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
- for (int mb = 0; mb < MB; ++mb)
- for (int od = 0; od < OD; ++od)
- for (int oh = 0; oh < OH; ++oh)
- for (int ow = 0; ow < OW; ++ow) {
- if (ow*KSW + kw * (1 + KDW) < padL
- || oh*KSH + kh * (1 + KDH) < padT
- || od*KSD + kd * (1 + KDD) < padFront
- || ow*KSW + kw * (1 + KDW) >= IW + padL
- || oh*KSH + kh * (1 + KDH) >= IH + padT
- || od*KSD + kd * (1 + KDD) >= ID + padFront)
- continue;
-
- int id = od*KSD - padFront + kd * (1 + KDD);
- int ih = oh*KSH - padT + kh * (1 + KDH);
- int iw = ow*KSW - padL + kw * (1 + KDW);
- if (ndims == 5)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od,
- oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)];
- else if (ndims == 4)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)]
- * src[src_d.off(mb, g*IC + ic, ih, iw)];
- else if (ndims == 3)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]
- * src[src_d.off(mb, g*IC + ic, iw)];
- else
- assert(false);
- }
- };
-
- auto ker_bias = [=](acc_data_t &d, int g, int oc) {
- for (int mb = 0; mb < MB; ++mb)
- for (int od = 0; od < OD; ++od)
- for (int oh = 0; oh < OH; ++oh)
- for (int ow = 0; ow < OW; ++ow) {
- if (ndims == 5)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh,
- ow)];
- else if (ndims == 4)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh,
- ow)];
- else if (ndims == 3)
- d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)];
- else
- assert(false);
- }
- };
-
- parallel_nd(G, OC, [&](int g, int oc) {
- if (diff_bias) {
- // XXX: loss of precision when bias is a float...
- acc_data_t db = 0;
- ker_bias(db, g, oc);
- diff_bias[diff_bias_d.off(g*OC+oc)]
- = saturate<diff_wei_data_t>(db);
- }
-
- for (int ic = 0; ic < IC; ++ic)
- for (int kd = 0; kd < KD; ++kd)
- for (int kh = 0; kh < KH; ++kh)
- for (int kw = 0; kw < KW; ++kw) {
- acc_data_t dw = 0;
- ker(dw, g, oc, ic, kd, kh, kw);
-
- if (ndims == 5) {
- auto idx = with_groups
- ? diff_weights_d.off(g, oc, ic, kd, kh, kw)
- : diff_weights_d.off(oc, ic, kd, kh, kw);
- diff_weights[idx] = saturate<diff_wei_data_t>(dw);
- } else if (ndims == 4) {
- auto idx = with_groups
- ? diff_weights_d.off(g, oc, ic, kh, kw)
- : diff_weights_d.off(oc, ic, kh, kw);
- diff_weights[idx] = saturate<diff_wei_data_t>(dw);
- } else if (ndims == 3) {
- auto idx = with_groups
- ? diff_weights_d.off(g, oc, ic, kw)
- : diff_weights_d.off(oc, ic, kw);
- diff_weights[idx] = saturate<diff_wei_data_t>(dw);
- } else {
- assert(false);
- }
- }
- });
-}
-
-using namespace data_type;
-
-template struct ref_convolution_fwd_t<f32>;
-
-template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
-template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
-template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
-template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
-
-template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
-
-template struct ref_convolution_bwd_data_t<f32, s8, u8, s32>;
-template struct ref_convolution_bwd_data_t<s32, s8, u8, s32>;
-template struct ref_convolution_bwd_data_t<s8, s8, u8, s32>;
-template struct ref_convolution_bwd_data_t<u8, s8, u8, s32>;
-
-template struct ref_convolution_bwd_weights_t<f32, f32, f32, f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
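The forward kernel above is a direct (naive) convolution; removing groups, dilation and the layout-dependent offset helpers leaves the familiar loop nest. A sketch for a single image in a dense NCHW-style layout (hypothetical helper; dst must be pre-sized to OC*OH*OW):

    #include <cstddef>
    #include <vector>

    void conv2d_fwd(const std::vector<float> &src, const std::vector<float> &wei,
                    std::vector<float> &dst, int IC, int OC, int IH, int IW,
                    int KH, int KW, int SH, int SW, int padT, int padL,
                    int OH, int OW) {
        for (int oc = 0; oc < OC; ++oc)
            for (int oh = 0; oh < OH; ++oh)
                for (int ow = 0; ow < OW; ++ow) {
                    float acc = 0.f;
                    for (int ic = 0; ic < IC; ++ic)
                        for (int kh = 0; kh < KH; ++kh)
                            for (int kw = 0; kw < KW; ++kw) {
                                int ih = oh * SH - padT + kh;
                                int iw = ow * SW - padL + kw;
                                if (ih < 0 || ih >= IH || iw < 0 || iw >= IW)
                                    continue; // zero padding
                                acc += src[((size_t)ic * IH + ih) * IW + iw]
                                     * wei[(((size_t)oc * IC + ic) * KH + kh) * KW + kw];
                            }
                    dst[((size_t)oc * OH + oh) * OW + ow] = acc;
                }
    }

The deleted backward-data and backward-weights kernels invert the same index relation (ih = oh*SH - padT + kh) instead of looping over input positions directly.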
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp
deleted file mode 100644
index 7c83d0c6d4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp
+++ /dev/null
@@ -1,194 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_CONVOLUTION_HPP
-#define CPU_REF_CONVOLUTION_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type,
- impl::data_type_t wei_type = src_type,
- impl::data_type_t dst_type = src_type,
- impl::data_type_t acc_type = dst_type>
-struct ref_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_fwd_pd_t {
- using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t);
-
- status_t init() {
- using namespace data_type;
-
- bool ok = true
- && is_fwd()
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, wei_type, data_type::undef,
- dst_type, acc_type)
- && IMPLICATION(with_bias(), true
- && IMPLICATION(src_type == u8,
- utils::one_of(bias_md_.data_type, f32, s32, s8, u8))
- && IMPLICATION(src_type == f32,
- bias_md_.data_type == f32))
- && set_default_formats()
- && attr()->has_default_values();
- return ok ? status::success : status::unimplemented;
- }
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
- impl::data_type_t diff_dst_type,
- impl::data_type_t acc_type = diff_src_type>
-struct ref_convolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_data_pd_t {
- using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(diff_src_type, wei_type, data_type::undef,
- diff_dst_type, acc_type)
- && set_default_formats()
- && attr()->has_default_values();
-
- return ok ? status::success : status::unimplemented;
- }
-
- virtual bool support_bias() const override { return true; }
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t src_type, impl::data_type_t diff_wei_type,
- impl::data_type_t diff_dst_type,
- impl::data_type_t acc_type = diff_wei_type>
-struct ref_convolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_convolution_bwd_weights_pd_t {
- using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && set_default_alg_kind(alg_kind::convolution_direct)
- && expect_data_types(src_type, diff_wei_type, diff_wei_type,
- diff_dst_type, acc_type)
- && set_default_formats()
- && attr()->has_default_values();
- return ok ? status::success : status::unimplemented;
- }
-
- protected:
- bool set_default_formats() {
- using namespace format_tag;
- auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
- auto wei_tag = with_groups()
- ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
- : utils::pick(ndims() - 3, oiw, oihw, oidhw);
- return set_default_formats_common(dat_tag, wei_tag, dat_tag);
- }
- };
-
- ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<diff_wei_type>::type diff_wei_data_t;
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp
deleted file mode 100644
index 541a303aab..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp
+++ /dev/null
@@ -1,199 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "mkldnn_traits.hpp"
-#include "math_utils.hpp"
-
-#include "ref_deconvolution.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias,
- data_t *dst) const {
- const memory_desc_wrapper dst_d(pd()->dst_md());
-
- const int G = pd()->G();
- const int MB = pd()->MB();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int OD = pd()->OD();
- const int OC = pd()->OC() / G;
- const int ndims = pd()->desc()->src_desc.ndims;
-
- parallel_nd(MB, G, OC, OD, OH, OW,
- [&](int mb, int g, int oc, int od, int oh, int ow) {
- auto b = bias[g * OC + oc];
- switch (ndims) {
- case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break;
- case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break;
- case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break;
- default: assert(!"invalid dimension size");
- }
- });
-}
-
-void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias,
- data_t *dst) const {
- const memory_desc_wrapper dst_d(pd()->dst_md());
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int SP = pd()->OW()*pd()->OH()*pd()->OD();
-
- parallel_nd(MB, OC, [&](int mb, int oc) {
- PRAGMA_OMP_SIMD()
- for (int sp = 0; sp < SP; ++sp) {
- auto offset = (size_t)(mb * OC + oc) * SP + sp;
- dst[offset] += bias[oc];
- }
- });
-}
-
-template <int blksize>
-void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias,
- data_t *dst) const {
- const memory_desc_wrapper dst_d(pd()->dst_md());
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int SP = pd()->OW() * pd()->OH() * pd()->OD();
-
- const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0];
-
- parallel_nd(MB, utils::div_up(OC, blksize), SP,
- [&](int mb, int oc_blk, int sp) {
- int oc = oc_blk * blksize;
- auto offset = mb * stride_mb + oc * SP + sp * blksize;
- const int blk = nstl::min(blksize, OC - oc);
-
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < blk; ++i)
- dst[offset + i] += bias[oc + i];
- });
-}
-
-void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst,
- data_t *diff_bias) const {
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
-
- const int G = pd()->G();
- const int MB = pd()->MB();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
- const int OC = pd()->OC() / G;
- const int OD = pd()->OD();
- const int ndims = pd()->desc()->src_desc.ndims;
-
- parallel_nd(G, OC, [&](int g, int oc) {
- data_t db = 0;
- for (int mb = 0; mb < MB; ++mb) {
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- switch (ndims) {
- case 5:
- db += diff_dst[diff_dst_d.off(
- mb, g * OC + oc, od, oh, ow)];
- break;
- case 4:
- db += diff_dst[diff_dst_d.off(
- mb, g * OC + oc, oh, ow)];
- break;
- case 3:
- db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)];
- break;
- default: assert(!"invalid dimension size");
- }
- }
- }
- }
- }
- diff_bias[g * OC + oc] = db;
- });
-}
-
-void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
- const data_t *diff_dst, data_t *diff_bias) const {
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
-
- const int OC = pd()->OC();
- const int MB = pd()->MB();
- const int SP = pd()->OH()*pd()->OW()*pd()->OD();
-
- parallel_nd(OC, [&](int oc) {
- data_t db = 0;
- for (int mb = 0; mb < MB; ++mb) {
- PRAGMA_OMP_SIMD()
- for (int sp = 0; sp < SP; ++sp) {
- auto offset = (size_t)(mb * OC + oc) * SP + sp;
- db += diff_dst[offset];
- }
- }
- diff_bias[oc] = db;
- });
-}
-
-template <int blksize>
-void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
- const data_t *diff_dst, data_t *diff_bias) const {
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
-
- const int OC = pd()->OC();
- const int MB = pd()->MB();
- const int SP = pd()->OH() * pd()->OW() * pd()->OD();
-
- const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];
-
- parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
- data_t db[blksize] = {0};
-
- for (int mb = 0; mb < MB; ++mb) {
- for (int sp = 0; sp < SP; ++sp) {
- auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
-
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < blksize; ++i)
- db[i] += diff_dst[offset+i];
- }
- }
-
- const int blk = nstl::min(blksize, OC - ocb * blksize);
-
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < blk; ++i)
- diff_bias[ocb * blksize + i] = db[i];
- });
-}
-
-template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>(
- const data_t *diff_dst, data_t *diff_bias) const;
-template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>(
- const data_t *diff_dst, data_t *diff_bias) const;
-template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>(
- const data_t *diff_dst, data_t *diff_bias) const;
-template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>(
- const data_t *diff_dst, data_t *diff_bias) const;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
deleted file mode 100644
index d61903c32d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
+++ /dev/null
@@ -1,502 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_DECONVOLUTION_HPP
-#define CPU_REF_DECONVOLUTION_HPP
-
-#include <assert.h>
-#include <string.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "primitive_iterator.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_deconvolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-static status_t compute_blocked_format(bool with_groups,
- const memory_desc_t *oi_md, memory_desc_t *io_md)
-{
- /* Computes blocking for *i*o* format from *o*i* format */
-
- bool sanity_check_ok = true
- && oi_md->ndims == io_md->ndims
- && oi_md->format_kind == format_kind::blocked;
- if (!sanity_check_ok) return status::invalid_arguments;
-
- const blocking_desc_t &oi_blk = oi_md->format_desc.blocking;
- blocking_desc_t io_blk = io_md->format_desc.blocking;
-
- io_md->format_kind = format_kind::blocked;
- io_blk = oi_blk;
-
- const int ID_OC = 0 + with_groups;
- const int ID_IC = 1 + with_groups;
-
- nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]);
- for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) {
- if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) {
- io_blk.inner_idxs[i_blk] =
- (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC);
- }
- }
-
- return memory_desc_init_by_blocking_desc(*io_md, io_blk);
-}
-
-static status_t conv_descr_create(const deconvolution_desc_t *dd,
- convolution_desc_t *cd)
-{
- using namespace prop_kind;
- alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
- ? alg_kind::convolution_direct : alg_kind::convolution_winograd;
-
- const memory_desc_t *src_md, *dst_md, *d_weights_d;
- prop_kind_t prop_kind;
- memory_desc_t c_weights_d;
- if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
- prop_kind = backward_data;
- src_md = &dd->dst_desc;
- dst_md = &dd->src_desc;
- d_weights_d = &dd->weights_desc;
- } else if (dd->prop_kind == backward_data) {
- prop_kind = forward_training;
- src_md = &dd->diff_dst_desc;
- dst_md = &dd->diff_src_desc;
- d_weights_d = &dd->weights_desc;
- } else {
- prop_kind = dd->prop_kind;
- src_md = &dd->diff_dst_desc;
- dst_md = &dd->src_desc;
- d_weights_d = &dd->diff_weights_desc;
- }
-
- const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
-
- /* create weights desc for convolution */
- c_weights_d = *d_weights_d;
-
- const int ID_OC = 0 + with_groups;
- const int ID_IC = 1 + with_groups;
-
- nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]);
- nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]);
- nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]);
-
- if (c_weights_d.format_kind != format_kind::any)
- CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d));
-
- return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
- prop_kind != backward_weights ? &dd->bias_desc : nullptr,
- dst_md, dd->strides, dd->dilates,
- dd->padding[0], dd->padding[1], dd->padding_kind);
-}
-
-struct ref_deconvolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_fwd_pd_t(other)
- , conv_pd_(other.conv_pd_->clone())
- , conv_supports_bias_(other.conv_supports_bias_)
- , dst_tag_(other.dst_tag_)
- {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- CHECK(conv_descr_create(desc(), &cd));
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- conv_supports_bias_ =
- static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_)
- ->support_bias();
- bool output_f32 = utils::everyone_is(data_type::f32,
- desc()->accum_data_type, desc()->dst_desc.data_type);
-
- bool ok = true
- && conv_pd_->weights_md()->extra.flags == 0
- /* deconv reference code can process only f32 bias */
- && IMPLICATION(with_bias(),
- conv_supports_bias_ || output_f32);
- if (ok) return status::success;
-
- delete conv_pd_;
- }
- conv_pd_ = nullptr;
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace format_tag;
- bool ok = true
- && is_fwd()
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd)
- && attr()->post_ops_.has_default_values();
-
- if (ok) {
- CHECK(init_convolution());
- if (weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->weights_md(), &desc_.weights_desc));
- weights_md_ = desc_.weights_desc;
- }
- if (src_md_.format_kind == format_kind::any)
- src_md_ = *conv_pd_->diff_dst_md();
- if (dst_md_.format_kind == format_kind::any)
- dst_md_ = *conv_pd_->diff_src_md();
- if (bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md_, x));
-
- dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
- utils::pick(ndims() - 3, ncw, nchw, ncdhw),
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
- utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- bool conv_supports_bias_;
- format_tag_t dst_tag_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_fwd_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
- conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
- if (pd()->with_bias() && pd()->conv_supports_bias_)
- conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS);
- conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- conv_p_->execute(conv_ctx);
-
- if (pd()->with_bias() && !pd()->conv_supports_bias_) {
- using namespace format_tag;
-
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- switch (pd()->dst_tag_) {
- case ncdhw: case nchw: case ncw:
- compute_fwd_bias_ncdhw(bias, dst);
- break;
- case nCdhw8c: case nChw8c: case nCw8c:
- compute_fwd_bias_nCdhwXc<8>(bias, dst);
- break;
- case nCdhw16c: case nChw16c: case nCw16c:
- compute_fwd_bias_nCdhwXc<16>(bias, dst);
- break;
- default:
- compute_fwd_bias(bias, dst);
- break;
- }
- }
- return status::success;
- }
-
-private:
- void compute_fwd_bias(const data_t *bias, data_t *dst) const;
- void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const;
- template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias,
- data_t *dst) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- primitive_t *conv_p_;
-};
-
-struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
- pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_bwd_data_pd_t(other)
- , conv_pd_(other.conv_pd_->clone()) {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- status_t status = conv_descr_create(desc(), &cd);
- if (status != status::success) return status;
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- if (conv_pd_->weights_md()->extra.flags == 0)
- return status::success;
- delete conv_pd_;
- }
-
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace data_type;
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && utils::everyone_is(data_type::f32,
- desc()->diff_src_desc.data_type,
- desc()->weights_desc.data_type,
- desc()->diff_dst_desc.data_type)
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd);
-
- if (ok) {
- CHECK(init_convolution());
- if (weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->weights_md(), &desc_.weights_desc));
- weights_md_ = desc_.weights_desc;
- }
- if (diff_src_md_.format_kind == format_kind::any)
- diff_src_md_ = *conv_pd_->dst_md();
- if (diff_dst_md_.format_kind == format_kind::any)
- diff_dst_md_ = *conv_pd_->src_md();
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_bwd_data_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
- conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
- conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- conv_p_->execute(conv_ctx);
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- primitive_t *conv_p_;
-};
-
-struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
- pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_bwd_weights_pd_t(other)
- , conv_pd_(other.conv_pd_->clone())
- , dst_tag_(other.dst_tag_)
- {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- status_t status = conv_descr_create(desc(), &cd);
- if (status != status::success) return status;
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- if (conv_pd_->diff_weights_md()->extra.flags == 0)
- return status::success;
- delete conv_pd_;
- }
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace format_tag;
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && utils::everyone_is(data_type::f32,
- desc()->src_desc.data_type,
- desc()->diff_weights_desc.data_type,
- desc()->diff_dst_desc.data_type)
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd)
- && attr()->has_default_values();
- if (ok) {
- CHECK(init_convolution());
- if (diff_weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->diff_weights_md(),
- &desc_.diff_weights_desc));
- diff_weights_md_ = desc_.diff_weights_desc;
- }
- if (src_md_.format_kind == format_kind::any)
- src_md_ = *conv_pd_->diff_dst_md();
- if (diff_dst_md_.format_kind == format_kind::any)
- diff_dst_md_ = *conv_pd_->src_md();
- if (diff_bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_bias_md_, x));
-
- dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
- utils::pick(ndims() - 3, ncw, nchw, ncdhw),
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
- utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- format_tag_t dst_tag_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_bwd_weights_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
- conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
- conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- status_t status = conv_p_->execute(conv_ctx);
- if (status != status::success) return status;
-
- if (pd()->with_bias()) {
- using namespace format_tag;
-
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- switch (pd()->dst_tag_) {
- case ncdhw: case nchw: case ncw:
- compute_bwd_bias_ncdhw(diff_dst, diff_bias);
- break;
- case nCdhw8c: case nChw8c: case nCw8c:
- compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias);
- break;
- case nCdhw16c: case nChw16c: case nCw16c:
- compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias);
- break;
- default:
- compute_bwd_bias(diff_dst, diff_bias);
- break;
- }
- }
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const;
- void compute_bwd_bias_ncdhw(const data_t *diff_dst,
- data_t *diff_bias) const;
- template <int blksize> void compute_bwd_bias_nCdhwXc(
- const data_t *diff_dst, data_t *diff_bias) const;
-
- primitive_t *conv_p_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp
deleted file mode 100644
index 7beee8d323..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp
+++ /dev/null
@@ -1,297 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_eltwise.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace alg_kind;
-using namespace math;
-
-ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
- float beta): alg_(alg), alpha_(alpha), beta_(beta) {
- assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
- eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
- eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
-}
-
-ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
- const post_ops_t::entry_t::eltwise_t &eltwise)
- : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {}
-
-float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
- switch (alg_) {
- case eltwise_relu: return relu_fwd(s, alpha_);
- case eltwise_tanh: return tanh_fwd(s);
- case eltwise_elu: return elu_fwd(s, alpha_);
- case eltwise_square: return square_fwd(s);
- case eltwise_abs: return abs_fwd(s);
- case eltwise_sqrt: return sqrt_fwd(s);
- case eltwise_linear: return linear_fwd(s, alpha_, beta_);
- case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_);
- case eltwise_soft_relu: return soft_relu_fwd(s);
- case eltwise_logistic: return logistic_fwd(s);
- default: assert(!"unknown eltwise alg_kind");
- }
-
- return 0.f;
-}
-
-template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const blocking_desc_t &blk = data_d.blocking_desc();
- const int block = blk.inner_blks[0];
-
- const int MB = pd()->MB();
- const int C = pd()->C() / block;
- const int C_PADDED = data_d.padded_dims()[1] / block;
- const int tail = pd()->C() % block;
- const int SP = pd()->D() * pd()->H() * pd()->W();
- const auto alg_kind = pd()->desc()->alg_kind;
- const float alpha = pd()->desc()->alpha;
- const float beta = pd()->desc()->beta;
-
- auto ker = [=] (data_t &d, data_t s) {
- switch (alg_kind) {
- case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
- case eltwise_bounded_relu:
- d = bounded_relu_fwd(s, alpha); break;
- case eltwise_soft_relu: d = soft_relu_fwd(s); break;
- case eltwise_logistic: d = logistic_fwd(s); break;
- default: assert(!"unknown eltwise alg_kind");
- }
- };
-
- // FIXME: integer overflow?
-
- parallel_nd(MB, C_PADDED, SP,
- [&](int n, int c, int sp) {
- auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
- if (c < C) {
- for (int v = 0; v < block; v++)
- ker(dst[d_off + v], src[d_off + v]);
- } else {
- for (int v = 0; v < tail; v++)
- ker(dst[d_off + v], src[d_off + v]);
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_generic(
- const exec_ctx_t &ctx) const {
- /* fast return */
- if (pd()->has_zero_dim_memory()) return;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- const int D = pd()->D();
- const int H = pd()->H();
- const int W = pd()->W();
- const auto alg_kind = pd()->desc()->alg_kind;
- const float alpha = pd()->desc()->alpha;
- const float beta = pd()->desc()->beta;
- const bool is_3d = pd()->desc()->data_desc.ndims == 5;
-
- parallel_nd(MB, C, D, H, W,
- [&](int n, int c, int id, int h, int w) {
- auto d_off = is_3d
- ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
- data_t s = src[d_off];
- data_t &d = dst[d_off];
- switch (alg_kind) {
- case eltwise_relu: d = relu_fwd(s, alpha); break;
- case eltwise_tanh: d = tanh_fwd(s); break;
- case eltwise_elu: d = elu_fwd(s, alpha); break;
- case eltwise_square: d = square_fwd(s); break;
- case eltwise_abs: d = abs_fwd(s); break;
- case eltwise_sqrt: d = sqrt_fwd(s); break;
- case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
- case eltwise_bounded_relu:
- d = bounded_relu_fwd(s, alpha); break;
- case eltwise_soft_relu: d = soft_relu_fwd(s); break;
- case eltwise_logistic: d = logistic_fwd(s); break;
- default: assert(!"unknown eltwise alg_kind");
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_eltwise_fwd_t<data_type>::execute_forward_dense(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
- const auto alg_kind = pd()->desc()->alg_kind;
- const float alpha = pd()->desc()->alpha;
- const float beta = pd()->desc()->beta;
-
- src += data_d.offset0();
- dst += data_d.offset0();
-
- if (alg_kind == eltwise_relu) {
- // a fast path for relu as the most popular activation
- parallel_nd(nelems, [&](ptrdiff_t e) {
- dst[e] = relu_fwd(src[e], alpha);
- });
- return;
- }
-
- parallel_nd(nelems, [&](ptrdiff_t e) {
- const data_t s = src[e];
- data_t &d = dst[e];
-
- switch (alg_kind) {
- case eltwise_tanh: d = tanh_fwd(s); break;
- case eltwise_elu: d = elu_fwd(s, alpha); break;
- case eltwise_square: d = square_fwd(s); break;
- case eltwise_abs: d = abs_fwd(s); break;
- case eltwise_sqrt: d = sqrt_fwd(s); break;
- case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
- case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
- case eltwise_soft_relu: d = soft_relu_fwd(s); break;
- case eltwise_logistic: d = logistic_fwd(s); break;
- default: assert(!"unknown eltwise alg_kind");
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_eltwise_bwd_t<data_type>::execute_backward_generic(
- const exec_ctx_t &ctx) const {
- /* fast return */
- if (pd()->has_zero_dim_memory()) return;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- const int D = pd()->D();
- const int H = pd()->H();
- const int W = pd()->W();
- const auto alg_kind = pd()->desc()->alg_kind;
- const float alpha = pd()->desc()->alpha;
- const float beta = pd()->desc()->beta;
- const bool is_3d = pd()->desc()->data_desc.ndims == 5;
-
- parallel_nd(MB, C, D, H, W,
- [&](int n, int c, int d, int h, int w) {
- auto data_off = is_3d
- ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
- auto diff_data_off = is_3d
- ? diff_data_d.off(n, c, d, h, w)
- : diff_data_d.off(n, c, h, w);
- data_t s = src[data_off];
- data_t dd = diff_dst[diff_data_off];
- data_t &ds = diff_src[diff_data_off];
- switch (alg_kind) {
- case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
- case eltwise_tanh: ds = tanh_bwd(dd, s); break;
- case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
- case eltwise_square: ds = square_bwd(dd, s); break;
- case eltwise_abs: ds = abs_bwd(dd, s); break;
- case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
- case eltwise_linear:
- ds = linear_bwd(dd, s, alpha, beta); break;
- case eltwise_bounded_relu:
- ds = bounded_relu_bwd(dd, s, alpha); break;
- case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
- case eltwise_logistic: ds = logistic_bwd(dd, s); break;
- default: assert(!"unknown eltwise alg_kind");
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
-
- const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
- const auto alg_kind = pd()->desc()->alg_kind;
- const float alpha = pd()->desc()->alpha;
- const float beta = pd()->desc()->beta;
-
- src += data_d.offset0();
- diff_dst += diff_data_d.offset0();
- diff_src += diff_data_d.offset0();
-
- parallel_nd(nelems, [&](ptrdiff_t e) {
- const data_t dd = diff_dst[e];
- const data_t s = src[e];
- data_t &ds = diff_src[e];
-
- switch (alg_kind) {
- case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
- case eltwise_tanh: ds = tanh_bwd(dd, s); break;
- case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
- case eltwise_square: ds = square_bwd(dd, s); break;
- case eltwise_abs: ds = abs_bwd(dd, s); break;
- case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
- case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
- case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
- case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
- case eltwise_logistic: ds = logistic_bwd(dd, s); break;
- default: assert(!"unknown eltwise alg_kind");
- }
- });
-}
-
-template struct ref_eltwise_fwd_t<data_type::f32>;
-template struct ref_eltwise_fwd_t<data_type::s32>;
-template struct ref_eltwise_fwd_t<data_type::s8>;
-template struct ref_eltwise_fwd_t<data_type::u8>;
-
-template struct ref_eltwise_bwd_t<data_type::f32>;
-template struct ref_eltwise_bwd_t<data_type::s32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp
deleted file mode 100644
index 8f4ab35413..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp
+++ /dev/null
@@ -1,168 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_ELTWISE_HPP
-#define CPU_REF_ELTWISE_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_eltwise_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct ref_eltwise_scalar_fwd_t {
-public:
- ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta);
-
- // note that eltwise.scale is ignored
- ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise);
-
- float compute_scalar(float s);
-
- const alg_kind_t alg_;
- const float alpha_;
- const float beta_;
-};
-
-template <impl::data_type_t data_type>
-struct ref_eltwise_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_eltwise_fwd_pd_t {
- using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t);
-
- status_t init() {
- using namespace utils;
-
- auto src_d = memory_desc_wrapper(src_md());
-
- use_dense_ = false
- || src_d.is_dense()
- || (src_d.is_dense(true) && is_zero_preserved());
-
- use_nCspBc_padded_ = !use_dense_
- && src_d.blocking_desc().inner_nblks == 1
- && one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
- && src_d.blocking_desc().inner_idxs[0] == 1
- && src_d.only_padded_dim(1)
- && src_d.is_dense(true);
-
- if (has_zero_dim_memory())
- use_dense_ = use_nCspBc_padded_ = false;
-
- const bool use_generic = !use_dense_ && !use_nCspBc_padded_;
-
- bool ok = true
- && is_fwd()
- && everyone_is(data_type, desc()->data_desc.data_type)
- && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5))
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- return status::success;
- }
-
- bool use_dense_, use_nCspBc_padded_;
- };
-
- ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (pd()->use_dense_)
- execute_forward_dense(ctx);
- else if (pd()->use_nCspBc_padded_)
- execute_forward_nCspBc_padded(ctx);
- else
- execute_forward_generic(ctx);
- return status::success;
- }
-
-private:
- void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const;
- void execute_forward_dense(const exec_ctx_t &ctx) const;
- void execute_forward_generic(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct ref_eltwise_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_eltwise_bwd_pd_t {
- using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t);
-
- status_t init() {
- using namespace utils;
-
- bool ok = true
- && !is_fwd()
- && everyone_is(data_type,
- desc()->data_desc.data_type,
- desc()->diff_data_desc.data_type)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- auto diff_dst_d = memory_desc_wrapper(diff_dst_md());
- const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md());
-
- use_dense_ = true
- && same_fmt_
- && diff_dst_d.is_dense(true)
- && is_zero_preserved()
- && !has_zero_dim_memory();
- const bool use_generic = !use_dense_;
-
- if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5))
- return status::unimplemented;
-
- return status::success;
- }
-
- bool use_dense_;
- };
-
- ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (pd()->use_dense_)
- execute_backward_dense(ctx);
- else
- execute_backward_generic(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_dense(const exec_ctx_t &ctx) const;
- void execute_backward_generic(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp
deleted file mode 100644
index c807a9ffd0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp
+++ /dev/null
@@ -1,285 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "mkldnn_traits.hpp"
-#include "math_utils.hpp"
-
-#include "ref_inner_product.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using math::saturate;
-using math::get_bias;
-
-template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
- data_type_t acc_type>
-void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>::
-execute_forward(const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper bias_d(pd()->weights_md(1));
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC();
-
- const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
- const int ndims = src_d.ndims() - 2;
-
- const auto &post_ops = pd()->attr()->post_ops_;
- const bool do_relu = post_ops.len_ == 1;
- const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
-
- auto ker_has_spatial = [=](int mb, int oc) {
- acc_data_t d = 0;
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- for (int ic = 0; ic < IC; ++ic) {
- for (int kd = 0; kd < KD; ++kd) {
- for (int kh = 0; kh < KH; ++kh) {
- for (int kw = 0; kw < KW; ++kw) {
- switch (ndims) {
- case 3:
- d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)]
- * weights[weights_d.off(
- oc, ic, kd, kh, kw)];
- break;
- case 2:
- d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)]
- * weights[weights_d.off(oc, ic, kh, kw)];
- break;
- case 1:
- d += (acc_data_t)src[src_d.off(mb, ic, kw)]
- * weights[weights_d.off(oc, ic, kw)];
- break;
- default: assert(!"unsupported ndims size");
- }
- }
- }
- }
- }
- return d;
- };
-
- auto ker_no_spatial = [=](int mb, int oc) {
- acc_data_t d = 0;
- for (int ic = 0; ic < IC; ++ic) {
- d += (acc_data_t)src[src_d.off(mb, ic)]
- * weights[weights_d.off(oc, ic)];
- }
- return d;
- };
-
- parallel_nd(MB, OC, [&](int mb, int oc) {
- float a = bias
- ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
- : 0;
- if (src_has_spatial)
- a += ker_has_spatial(mb, oc);
- else
- a += ker_no_spatial(mb, oc);
- if (do_relu && a < (acc_data_t)0)
- a *= nslope;
- dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
- });
-}
-
-using namespace data_type;
-template struct ref_inner_product_fwd_t<f32>;
-template struct ref_inner_product_fwd_t<u8, s8, f32, s32>;
-template struct ref_inner_product_fwd_t<u8, s8, s32, s32>;
-template struct ref_inner_product_fwd_t<u8, s8, s8, s32>;
-template struct ref_inner_product_fwd_t<u8, s8, u8, s32>;
-
-template <data_type_t diff_src_type, data_type_t wei_type,
- data_type_t diff_dst_type, data_type_t acc_type>
-void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
- acc_type>::execute_backward_data(const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
- auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
- auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper weights_d(pd()->weights_md(0));
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC();
-
- const bool diff_src_has_spatial
- = utils::one_of(diff_src_d.ndims(), 3, 4, 5);
- const int ndims = diff_src_d.ndims() - 2;
-
- parallel_nd(MB, IC, [&](int mb, int ic) {
- if (diff_src_has_spatial) {
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- for (int kd = 0; kd < KD; ++kd)
- for (int kh = 0; kh < KH; ++kh)
- for (int kw = 0; kw < KW; ++kw) {
- acc_data_t ds = acc_data_t(0);
- for (int oc = 0; oc < OC; ++oc) {
- switch (ndims) {
- case 3:
- ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
- * weights[weights_d.off(oc, ic, kd, kh, kw)]);
- break;
- case 2:
- ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
- * weights[weights_d.off(oc, ic, kh, kw)]);
- break;
- case 1:
- ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
- * weights[weights_d.off(oc, ic, kw)]);
- break;
- default: assert(!"unsupported ndims size");
- }
- }
- switch (ndims) {
- case 3:
- diff_src[diff_src_d.off(mb, ic, kd, kh, kw)]
- = (diff_src_data_t)ds;
- break;
- case 2:
- diff_src[diff_src_d.off(mb, ic, kh, kw)]
- = (diff_src_data_t)ds;
- break;
- case 1:
- diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds;
- break;
- default: assert(!"unsupported ndims size");
- }
- }
- } else {
- acc_data_t ds = acc_data_t(0);
- for (int oc = 0; oc < OC; ++oc) {
- ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] *
- weights[weights_d.off(oc, ic)]);
- }
- diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds;
- }
- });
-}
-
-template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>;
-
-template <impl::data_type_t data_type>
-void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights(
- const exec_ctx_t &ctx) const {
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
- const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
-
- const int MB = pd()->MB();
- const int OC = pd()->OC();
- const int IC = pd()->IC();
-
- const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4 ,5);
- const int ndims = src_d.ndims() - 2;
-
- parallel_nd(OC, IC, [&](int oc, int ic) {
- if (src_has_spatial) {
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- for (int kd = 0; kd < KD; ++kd) {
- for (int kh = 0; kh < KH; ++kh) {
- for (int kw = 0; kw < KW; ++kw) {
- data_t *dw(nullptr);
- switch (ndims) {
- case 3:
- dw = &diff_weights[diff_weights_d.off(
- oc, ic, kd, kh, kw)];
- break;
- case 2:
- dw = &diff_weights[diff_weights_d.off(
- oc, ic, kh, kw)];
- break;
- case 1:
- dw = &diff_weights[diff_weights_d.off(oc, ic, kw)];
- break;
- default: assert(!"unsupported ndims size");
- }
- *dw = data_t(0);
- for (int mb = 0; mb < MB; ++mb) {
- switch (ndims) {
- case 3:
- *dw += diff_dst[diff_dst_d.off(mb, oc)]
- * src[src_d.off(mb, ic, kd, kh, kw)];
- break;
- case 2:
- *dw += diff_dst[diff_dst_d.off(mb, oc)]
- * src[src_d.off(mb, ic, kh, kw)];
- break;
- case 1:
- *dw += diff_dst[diff_dst_d.off(mb, oc)]
- * src[src_d.off(mb, ic, kw)];
- break;
- default: assert(!"unsupported ndims size");
- }
- }
- }
- }
- }
- } else {
- data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)];
- *dw = data_t(0);
- for (int mb = 0; mb < MB; ++mb) {
- *dw += diff_dst[diff_dst_d.off(mb, oc)] *
- src[src_d.off(mb, ic)];
- }
- }
- });
-
- if (diff_bias) {
- diff_bias += diff_bias_d.offset0();
-
- parallel_nd(OC, [&](int oc) {
- data_t *db = &diff_bias[oc];
- *db = data_t(0);
- for (int mb = 0; mb < MB; ++mb)
- *db += diff_dst[diff_dst_d.off(mb, oc)];
- });
- }
-}
-
-template struct ref_inner_product_bwd_weights_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp
deleted file mode 100644
index bf87dbd514..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp
+++ /dev/null
@@ -1,159 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_INNER_PRODUCT_HPP
-#define CPU_REF_INNER_PRODUCT_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_inner_product_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
- impl::data_type_t dst_type = src_type,
- impl::data_type_t acc_type = dst_type>
-struct ref_inner_product_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_fwd_pd_t {
- using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t);
-
- status_t init() {
- using namespace data_type;
-
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && src_md()->data_type == src_type
- && weights_md()->data_type == wei_type
- && desc()->accum_data_type == acc_type
- && dst_md()->data_type == dst_type
- && IMPLICATION(with_bias(), utils::one_of(
- weights_md(1)->data_type, f32, s32, s8, u8))
- && attr()->output_scales_.has_default_values()
- && attr()->post_ops_.len_ <= 1
- && IMPLICATION(attr()->post_ops_.len_ == 1,
- attr()->post_ops_.entry_[0].is_relu(true, false));
- return ok ? status::success : status::unimplemented;
- }
- };
-
- ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<dst_type>::type dst_data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
- impl::data_type_t diff_dst_type,
- impl::data_type_t acc_type = diff_src_type>
-struct ref_inner_product_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_bwd_data_pd_t {
- using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && desc()->prop_kind == prop_kind::backward_data
- && diff_src_md()->data_type == diff_src_type
- && weights_md()->data_type == wei_type
- && desc()->accum_data_type == acc_type
- && diff_dst_md()->data_type == diff_dst_type
- && attr()->has_default_values();
- return ok ? status::success : status::unimplemented;
- }
- };
-
- ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
- typedef typename prec_traits<wei_type>::type wei_data_t;
- typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_data(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_data(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct ref_inner_product_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
- using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && desc()->prop_kind == prop_kind::backward_weights
- && utils::everyone_is(data_type,
- src_md()->data_type,
- diff_dst_md()->data_type,
- diff_weights_md()->data_type)
- && IMPLICATION(with_bias(),
- data_type == diff_weights_md(1)->data_type)
- && attr()->has_default_values();
- return ok ? status::success : status::unimplemented;
- }
- };
-
- ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward_weights(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_weights(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp
deleted file mode 100644
index 325e97963b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp
+++ /dev/null
@@ -1,252 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-
-#include "ref_lrn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-static inline float fast_negative_powf(float omega, float beta) {
- float Y;
-/*
- * Y = omega^(-3/4) =
- * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega))
- * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega)
- * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega)
- * = sqrtf(1.0f / sqrtf(omega) / omega)
- * = sqrtf(1.0f / (sqrtf(omega) * omega))
- */
- if (beta == 0.75f) {
- Y = sqrtf(1.0f / (sqrtf(omega) * omega));
- } else {
- Y = 1.0f / powf(omega, beta);
- }
- return Y;
-};
-
-template <impl::data_type_t data_type>
-template <impl::format_tag_t tag>
-void ref_lrn_fwd_t<data_type>::execute_forward(const exec_ctx_t &ctx) const {
- using namespace alg_kind;
- using namespace format_tag;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const size_t stride_mb = data_d.blocking_desc().strides[0];
- const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
- constexpr int blksize = tag == nChw16c ? 16 : 8;
-
- auto data_off = [&](int mb, int c, int h, int w) -> size_t {
- switch (tag) {
- case nChw16c:
- case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize
- + h * W * blksize + w * blksize + c % blksize;
- case nchw: return mb * stride_mb + c * H * W + h * W + w;
- case nhwc: return mb * stride_mb + h * W * C + w * C + c;
- default: return data_d.off(mb, c, h, w);
- }
- };
-
- auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
- const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
- const float beta = static_cast<float>(pd()->desc()->lrn_beta);
- const float k = static_cast<float>(pd()->desc()->lrn_k);
-
- const int size = pd()->desc()->local_size;
- const int half_size = (size - 1) / 2;
-
- float sum = 0;
- if (across_channels) {
- const int c_st = nstl::max(oc - half_size + 0, 0);
- const int c_en = nstl::min(oc + half_size + 1, C);
-
- for (int c = c_st; c < c_en; ++c) {
- const float s = src[data_off(mb, c, oh, ow)];
- sum += s * s;
- }
- } else {
- int h_st = nstl::max(oh - half_size + 0, 0);
- int h_en = nstl::min(oh + half_size + 1, H);
- int w_st = nstl::max(ow - half_size + 0, 0);
- int w_en = nstl::min(ow + half_size + 1, W);
- for (int h = h_st; h < h_en; ++h) {
- for (int w = w_st; w < w_en; ++w) {
- const float s = src[data_off(mb, oc, h, w)];
- sum += s * s;
- }
- }
- }
- const int summands = across_channels ? size : size * size;
- sum = k + alpha * sum / summands;
- size_t off = data_off(mb, oc, oh, ow);
- d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta));
- };
-
- const int MB = pd()->MB();
- if (tag == nChw16c || tag == nChw8c) {
- parallel_nd(MB, utils::div_up(C, blksize), H, W,
- [&](int mb, int c_blk, int h, int w) {
- int c = c_blk * blksize;
- const size_t off = mb * stride_mb + c * H * W
- + (h * W + w) * blksize;
- PRAGMA_OMP_SIMD()
- for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
- ker(&dst[off + cc], mb, c + cc, h, w);
- });
- } else if (tag == nhwc) {
- parallel_nd(MB, H, W, C,
- [&](int mb, int h, int w, int c) {
- const size_t off = mb * stride_mb + h * W * C + w * C + c;
- ker(&dst[off], mb, c, h, w);
- });
- } else {
- parallel_nd(MB, C, H, W,
- [&](int mb, int c, int h, int w) {
- const size_t off = data_off(mb, c, h, w);
- ker(&dst[off], mb, c, h, w);
- });
- }
-}
-
-template <impl::data_type_t data_type>
-template <mkldnn_format_tag_t tag>
-void ref_lrn_bwd_t<data_type>::execute_backward(const exec_ctx_t &ctx) const {
- using namespace alg_kind;
- using namespace format_tag;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- const int H = pd()->H();
- const int W = pd()->W();
- const size_t stride_mb = data_d.blocking_desc().strides[0];
- constexpr int blksize = tag == nChw16c ? 16 : 8;
-
- const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
- const float beta = static_cast<float>(pd()->desc()->lrn_beta);
- const float k = static_cast<float>(pd()->desc()->lrn_k);
- const int kernel_size = pd()->desc()->local_size;
- const int half_ksize = (kernel_size - 1) / 2;
-
- auto data_off = [&](int mb, int c, int h, int w) -> size_t {
- switch (tag) {
- case nChw16c:
- case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize
- + h * W * blksize + w * blksize + c%blksize;
- case nchw: return mb * stride_mb + c * H * W + h * W + w;
- case nhwc: return mb * stride_mb + h * W * C + w * C + c;
- default: return data_d.off(mb, c, h, w);
- }
- };
-
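-    // Per-element LRN backward. With omega_c = k + alpha * sum(src^2) / n
-    // over the window centered at channel c and n = local_size:
-    //   diff_src = diff_dst * omega_oc^(-beta)
-    //            - (2 * alpha * beta / n) * src
-    //              * sum_c(diff_dst_c * src_c * omega_c^(-beta - 1))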
- auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
- const int c_st = nstl::max(oc - half_ksize + 0, 0);
- const int c_en = nstl::min(oc + half_ksize + 1, C);
-
- float A = 0, B = 0, omega_mid = 0;
- for (int c = c_st; c < c_en; c++) {
- float sum = 0.0;
- const int i_st = nstl::max(c - half_ksize, 0);
- const int i_en = nstl::min(c + kernel_size - half_ksize, C);
-
- for (int i = i_st; i < i_en; ++i) {
- const float value = src[data_off(mb, i, oh, ow)];
- sum += value * value;
- }
- const float omega = static_cast<float>(k + sum * alpha / kernel_size);
- if (c == oc) omega_mid = omega;
- float t = src[data_off(mb, c, oh, ow)]
- * fast_negative_powf(omega, beta);
- B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)];
- }
-
- const size_t off = data_off(mb, oc, oh, ow);
- A = fast_negative_powf(omega_mid, beta) * diff_dst[off];
- B *= src[off];
- B *= (2.0f * alpha * beta) / kernel_size;
- *d = static_cast<data_t>(A - B); // final cast down to data_t
- };
-
- if (tag == nChw16c || tag == nChw8c) {
- parallel_nd(MB, utils::div_up(C, blksize), H, W,
- [&](int mb, int c_blk, int h, int w) {
- int c = c_blk * blksize;
- const size_t off = mb * stride_mb + c * H * W +
- (h * W + w) * blksize;
- PRAGMA_OMP_SIMD()
- for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
- ker(&diff_src[off + cc], mb, c + cc, h, w);
- });
- } else if (tag == nhwc) {
- parallel_nd(MB, H, W, C,
- [&](int mb, int h, int w, int c) {
- const size_t off = mb * stride_mb + h * W * C + w * C + c;
- ker(&diff_src[off], mb, c, h, w);
- });
- } else {
- parallel_nd(MB, C, H, W,
- [&](int mb, int c, int h, int w) {
- const size_t off = data_off(mb, c, h, w);
- ker(&diff_src[off], mb, c, h, w);
- });
- }
-}
-
-template void ref_lrn_fwd_t<data_type::f32>::
-execute_forward<format_tag::nChw16c>(const exec_ctx_t &ctx) const;
-template void ref_lrn_fwd_t<data_type::f32>::
-execute_forward<format_tag::nChw8c>(const exec_ctx_t &ctx) const;
-template void ref_lrn_fwd_t<data_type::f32>::
-execute_forward<format_tag::nchw>(const exec_ctx_t &ctx) const;
-template void ref_lrn_fwd_t<data_type::f32>::
-execute_forward<format_tag::nhwc>(const exec_ctx_t &ctx) const;
-template void ref_lrn_fwd_t<data_type::f32>::
-execute_forward<format_tag::any>(const exec_ctx_t &ctx) const;
-template void ref_lrn_bwd_t<data_type::f32>::
-execute_backward<format_tag::nChw16c>(const exec_ctx_t &ctx) const;
-template void ref_lrn_bwd_t<data_type::f32>::
-execute_backward<format_tag::nChw8c>(const exec_ctx_t &ctx) const;
-template void ref_lrn_bwd_t<data_type::f32>::
-execute_backward<format_tag::nchw>(const exec_ctx_t &ctx) const;
-template void ref_lrn_bwd_t<data_type::f32>::
-execute_backward<format_tag::nhwc>(const exec_ctx_t &ctx) const;
-template void ref_lrn_bwd_t<data_type::f32>::
-execute_backward<format_tag::any>(const exec_ctx_t &ctx) const;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp
deleted file mode 100644
index f25cfb7fae..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp
+++ /dev/null
@@ -1,136 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_LRN_HPP
-#define CPU_REF_LRN_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_lrn_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-struct ref_lrn_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_fwd_pd_t {
- using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t);
-
- status_t init() {
- using namespace format_tag;
-
- bool ok = true
- && is_fwd()
- && src_md()->data_type == data_type
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- dat_tag_ = memory_desc_matches_one_of_tag(
- *src_md(), nChw16c, nChw8c, nchw, nhwc);
-
- return status::success;
- }
-
- format_tag_t dat_tag_;
- };
-
- ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- using namespace format_tag;
- switch (pd()->dat_tag_) {
- case nChw16c: execute_forward<nChw16c>(ctx); break;
- case nChw8c: execute_forward<nChw8c>(ctx); break;
- case nchw: execute_forward<nchw>(ctx); break;
- case nhwc: execute_forward<nhwc>(ctx); break;
- default: execute_forward<any>(ctx);
- }
- return status::success;
- }
-
-private:
- template<format_tag_t tag>
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type>
-struct ref_lrn_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_lrn_bwd_pd_t {
- using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t);
-
- status_t init() {
- using namespace format_tag;
- using namespace alg_kind;
-
- bool ok = true
- && !is_fwd()
- && utils::one_of(desc()->alg_kind, lrn_across_channels
- /*, lrn_within_channel */) // not supported yet
- && utils::everyone_is(data_type,
- src_md()->data_type,
- diff_src_md()->data_type)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- dat_tag_ = memory_desc_matches_one_of_tag(
- *src_md(), nChw16c, nChw8c, nchw, nhwc);
-
- return status::success;
- }
-
- format_tag_t dat_tag_;
- };
-
- ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- using namespace format_tag;
- switch (pd()->dat_tag_) {
- case nChw16c: execute_backward<nChw16c>(ctx); break;
- case nChw8c: execute_backward<nChw8c>(ctx); break;
- case nchw: execute_backward<nchw>(ctx); break;
- case nhwc: execute_backward<nhwc>(ctx); break;
- default: execute_backward<any>(ctx);
- }
- return status::success;
- }
-
-private:
- template<format_tag_t tag>
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp
deleted file mode 100644
index 65b934e123..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp
+++ /dev/null
@@ -1,381 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-
-#include "ref_pooling.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t data_type, data_type_t acc_type>
-void ref_pooling_fwd_t<data_type, acc_type>::execute_forward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
- using namespace prop_kind;
-
- auto alg = pd()->desc()->alg_kind;
-
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
- auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
-
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
- const memory_desc_wrapper ws_d(pd()->workspace_md());
- const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
-
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- const bool is_3d = pd()->desc()->src_desc.ndims == 5;
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) {
- if (ws) {
- assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
- size_t offset = is_3d
-                ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);
- if (ws_dt == data_type::u8) {
- assert(0 <= value && value <= 255);
- ws[offset] = value;
- } else
- reinterpret_cast<int *>(ws)[offset] = value;
- }
- };
-
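-    // Max pooling: take the maximum over the kernel window and record the
-    // winning kernel offset in the workspace so backward can route the
-    // gradient to it.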
- auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) {
- for (int kh = 0; kh < KH; ++kh) {
- for (int kw = 0; kw < KW; ++kw) {
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
- if (ih < 0 || ih >= IH) continue;
- if (iw < 0 || iw >= IW) continue;
-
- auto s = src[src_d.off(mb, oc, ih, iw)];
- if (s > d[0]) {
- d[0] = s;
- set_ws(mb, oc, 1, oh, ow, kh*KW + kw);
- }
- }
- }
- };
-
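-    // Average pooling: sum over the (clipped) window and divide either by
-    // the full kernel size or by the number of valid elements, depending
-    // on whether padding is included in the averaging.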
- auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) {
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH
- : (ih_end - ih_start)*(iw_end - iw_start);
-
- acc_data_t dst = 0;
- for (int ih = ih_start; ih < ih_end; ++ih) {
- for (int iw = iw_start; iw < iw_end; ++iw) {
- dst += src[src_d.off(mb, oc, ih, iw)];
- }
- }
-
- d[0] = math::out_round<data_t>((float)dst / num_summands);
- };
-
- auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
- for (int kd = 0; kd < KD; ++kd) {
- for (int kh = 0; kh < KH; ++kh) {
- for (int kw = 0; kw < KW; ++kw) {
- const int id = od * SD - padF + kd;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
- if (id < 0 || id >= ID) continue;
- if (ih < 0 || ih >= IH) continue;
- if (iw < 0 || iw >= IW) continue;
-
- auto s = src[src_d.off(mb, oc, id, ih, iw)];
- if (s > d[0]) {
- d[0] = s;
- set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw);
- }
- }
- }
- }
- };
-
- auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
- auto id_start = apply_offset(od*SD, padF);
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto id_end = nstl::min(od*SD - padF + KD, ID);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD
- : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
-
- acc_data_t dst = 0;
- for (int id = id_start; id < id_end; ++id) {
- for (int ih = ih_start; ih < ih_end; ++ih) {
- for (int iw = iw_start; iw < iw_end; ++iw) {
- dst += src[src_d.off(mb, oc, id, ih, iw)];
- }
- }
- }
-
- d[0] = math::out_round<data_t>((float)dst / num_summands);
- };
-
- const int MB = pd()->MB();
- const int OC = pd()->C();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
-
- if (alg == pooling_max) {
- parallel_nd(MB, OC, OD, OH, OW,
- [&](int mb, int oc, int od, int oh, int ow) {
- data_t *d = is_3d
- ? &dst[dst_d.off(mb, oc, od, oh, ow)]
- : &dst[dst_d.off(mb, oc, oh, ow)];
- d[0] = nstl::numeric_limits<data_t>::lowest();
- set_ws(mb, oc, od, oh, ow, 0);
- if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow);
- else ker_max(d, mb, oc, oh, ow);
- });
- } else {
- parallel_nd(MB, OC, OD, OH, OW,
- [&](int mb, int oc, int od, int oh, int ow) {
- data_t *d = is_3d
- ? &dst[dst_d.off(mb, oc, od, oh, ow)]
- : &dst[dst_d.off(mb, oc, oh, ow)];
- d[0] = 0;
- if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow);
- else ker_avg(d, mb, oc, oh, ow);
- });
- }
-}
-
-template <data_type_t data_type, data_type_t acc_type>
-void ref_pooling_bwd_t<data_type, acc_type>::execute_backward(
- const exec_ctx_t &ctx) const {
- using namespace alg_kind;
-
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
- const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
- const memory_desc_wrapper ws_d(pd()->workspace_md());
-
- const int ID = pd()->ID();
- const int IH = pd()->IH();
- const int IW = pd()->IW();
- const int KD = pd()->KD();
- const int KH = pd()->KH();
- const int KW = pd()->KW();
- const int SD = pd()->KSD();
- const int SH = pd()->KSH();
- const int SW = pd()->KSW();
- const int padF = pd()->padFront();
- const int padT = pd()->padT();
- const int padL = pd()->padL();
-
- const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
-
- auto alg = pd()->desc()->alg_kind;
-
- auto apply_offset = [=](int index, int offset) {
- return (index > offset) ? index - offset : 0;
- };
-
- auto ker_zero = [=](int _mb, int _oc) {
- for (int ih = 0; ih < IH; ++ih) {
- for (int iw = 0; iw < IW; ++iw) {
- diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_type_t(0);
- }
- }
- };
-
- auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) {
- const size_t ws_off = ws_d.off(mb, oc, oh, ow);
- const int index = ws_d.data_type() == data_type::u8
- ? (int)ws[ws_off] : ((int *)ws)[ws_off];
- const int kw = index % KW;
- const int kh = index / KW;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
-        // If the padding area is large enough to fit the whole kernel,
-        // the recorded input position can fall outside the input. There is
-        // nothing to back-propagate in that case, since the padding is
-        // virtual in the pooling_max case.
- if (ih < 0 || ih >= IH)
- return;
- if (iw < 0 || iw >= IW)
- return;
-
- diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0];
- };
-
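-    // Average pooling backward: the output gradient is spread evenly over
-    // the window, each contributing input receiving d / num_summands.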
- auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) {
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH
- : (ih_end - ih_start)*(iw_end - iw_start);
-
- for (int ih = ih_start; ih < ih_end; ++ih) {
- for (int iw = iw_start; iw < iw_end; ++iw) {
- diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands;
- }
- }
- };
-
- auto ker_zero_3d = [=](int _mb, int _oc) {
- for (int id = 0; id < ID; ++id) {
- for (int ih = 0; ih < IH; ++ih) {
- for (int iw = 0; iw < IW; ++iw) {
- diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] =
- data_type_t(0);
- }
- }
- }
- };
-
- auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh,
- int ow) {
- const size_t ws_off = ws_d.off(mb, oc, od, oh, ow);
- const int index = ws_d.data_type() == data_type::u8
- ? (int)ws[ws_off] : ((int *)ws)[ws_off];
- const int kw = index % KW;
- const int kh = (index / KW) % KH;
- const int kd = (index / KW) / KH;
- const int id = od * SD - padF + kd;
- const int ih = oh * SH - padT + kh;
- const int iw = ow * SW - padL + kw;
-
-        // If the padding area is large enough to fit the whole kernel,
-        // the recorded input position can fall outside the input. There is
-        // nothing to back-propagate in that case, since the padding is
-        // virtual in the pooling_max case.
- if (id < 0 || id >= ID)
- return;
- if (ih < 0 || ih >= IH)
- return;
- if (iw < 0 || iw >= IW)
- return;
-
- diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0];
- };
-
- auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh,
- int ow) {
- auto id_start = apply_offset(od*SD, padF);
- auto ih_start = apply_offset(oh*SH, padT);
- auto iw_start = apply_offset(ow*SW, padL);
- auto id_end = nstl::min(od*SD - padF + KD, ID);
- auto ih_end = nstl::min(oh*SH - padT + KH, IH);
- auto iw_end = nstl::min(ow*SW - padL + KW, IW);
-
- auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD
- : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
-
- for (int id = id_start; id < id_end; ++id)
- for (int ih = ih_start; ih < ih_end; ++ih)
- for (int iw = iw_start; iw < iw_end; ++iw) {
- diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands;
- }
- };
-
- const int MB = pd()->MB();
- const int OC = pd()->C();
- const int OD = pd()->OD();
- const int OH = pd()->OH();
- const int OW = pd()->OW();
-
- if (pd()->desc()->alg_kind == alg_kind::pooling_max) {
- parallel_nd(MB, OC, [&](int mb, int oc) {
- if (is_3d) ker_zero_3d(mb, oc);
- else ker_zero(mb, oc);
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- const data_t *d = is_3d
- ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
- : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
- if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow);
- else ker_max(d, mb, oc, oh, ow);
- }
- }
- }
- });
- } else {
- parallel_nd(MB, OC, [&](int mb, int oc) {
- if (is_3d) ker_zero_3d(mb, oc);
- else ker_zero(mb, oc);
- for (int od = 0; od < OD; ++od) {
- for (int oh = 0; oh < OH; ++oh) {
- for (int ow = 0; ow < OW; ++ow) {
- const data_t *d = is_3d
- ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
- : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
- if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow);
- else ker_avg(d, mb, oc, oh, ow);
- }
- }
- }
- });
- }
-}
-
-template struct ref_pooling_fwd_t<data_type::f32>;
-template struct ref_pooling_fwd_t<data_type::s32>;
-template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>;
-template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>;
-
-template struct ref_pooling_bwd_t<data_type::f32>;
-template struct ref_pooling_bwd_t<data_type::s32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp
deleted file mode 100644
index e43ceaa82b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp
+++ /dev/null
@@ -1,119 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_POOLING_HPP
-#define CPU_REF_POOLING_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_pooling_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type>
-struct ref_pooling_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_fwd_pd_t {
- using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && is_fwd()
- && utils::everyone_is(data_type, src_md()->data_type,
- dst_md()->data_type)
- && desc()->accum_data_type == acc_type
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- bool is_training = desc_.prop_kind == prop_kind::forward_training;
- if (desc()->alg_kind == alg_kind::pooling_max && is_training)
- init_default_ws();
-
- return status::success;
- }
- };
-
- ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- typedef typename prec_traits<data_type>::type data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_forward(ctx);
- return status::success;
- }
-
-private:
- void execute_forward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type>
-struct ref_pooling_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_pooling_bwd_pd_t {
- using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t);
-
- status_t init() {
- bool ok = true
- && set_default_params() == status::success
- && !is_fwd()
- && utils::everyone_is(data_type, diff_dst_md()->data_type,
- diff_src_md()->data_type)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- if (desc()->alg_kind == alg_kind::pooling_max) {
- init_default_ws();
- if (!compare_ws(hint_fwd_pd_))
- return status::unimplemented;
- }
-
- return status::success;
- }
- };
-
- ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
- typedef typename prec_traits<data_type>::type data_t;
- typedef typename prec_traits<acc_type>::type acc_data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_backward(ctx);
- return status::success;
- }
-
-private:
- void execute_backward(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp
deleted file mode 100644
index af27743110..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp
+++ /dev/null
@@ -1,153 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-
-#include "ref_shuffle.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace format_tag;
-
-template <int data_type_size>
-template <mkldnn_format_tag_t tag>
-void ref_shuffle_t<data_type_size>::execute_(const exec_ctx_t &ctx) const {
- using namespace prop_kind;
- using namespace utils;
-
- const memory_desc_wrapper data_d(pd()->data_md());
-
- auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST;
- auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC;
- auto input = CTX_IN_MEM(const data_t *, i_arg);
- auto output = CTX_OUT_MEM(data_t *, o_arg);
-
- const int axis = pd()->axis();
- const int axis_size = pd()->axis_size();
-
- const int MB = pd()->MB();
- const int C = pd()->C();
- int H = 1, W = 1, D = 1, HW = 1, SP = 1;
-    const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4, 5);
- if (has_spatial)
- {
- D = pd()->D();
- H = pd()->H();
- W = pd()->W();
- HW = H * W;
- SP = D * HW;
- }
- const size_t stride_mb = data_d.blocking_desc().strides[0];
- constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8;
-
-    if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw8c)) {
-#if MKLDNN_THR == MKLDNN_THR_OMP
-# pragma omp parallel for collapse(3) schedule(static)
- for (int mb = 0; mb < MB; ++mb)
- for (int cb = 0; cb < C; cb += blksize)
- for (int sp = 0; sp < SP; ++sp) {
- const size_t off = mb * stride_mb + sp * blksize;
- const size_t output_off = off + cb * SP;
- PRAGMA_OMP_SIMD()
- for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
- {
- int input_c = rev_transposed_[cb + cc];
- const size_t input_off = off + input_c / blksize * SP * blksize
- + input_c % blksize;
- output[output_off + cc] = input[input_off];
- }
- }
-#else
- parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c,
- int sp) {
- const size_t off = mb * stride_mb + sp * blksize;
- const int cb = c * blksize;
- const size_t output_off = off + cb * SP;
- for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
- {
- int input_c = rev_transposed_[cb + cc];
- const size_t input_off = off + input_c / blksize * SP * blksize
- + input_c % blksize;
- output[output_off + cc] = input[input_off];
- }
- });
-#endif
- } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) {
- parallel_nd(MB, SP, [&](int mb, int sp) {
- const size_t off = mb * stride_mb + sp * C;
- PRAGMA_OMP_SIMD()
- for (int c = 0; c < C; ++c)
- output[off + c] = input[off + rev_transposed_[c]];
- });
- } else if (axis == 1 && one_of(tag, nchw, ncdhw)) {
- parallel_nd(MB, C, [&](int mb, int c) {
- const size_t output_off = mb * stride_mb + c * SP;
- const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP;
- PRAGMA_OMP_SIMD()
- for (int sp = 0; sp < SP; ++sp) {
- output[output_off + sp] = input[input_off + sp];
- }
- });
- } else {
- auto dims = pd()->desc()->data_desc.dims;
- auto ndims = pd()->desc()->data_desc.ndims;
- const size_t outer_size = utils::array_product(dims, axis);
- const size_t inner_size = utils::array_product(dims + axis + 1,
- ndims - axis - 1);
- const size_t dim = axis_size * inner_size;
-
- parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a,
- size_t in)
- {
- const size_t off = ou * dim + in;
- auto &o = output[data_d.off_l(off + a * inner_size)];
- o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)];
- });
- }
-}
-
-template void ref_shuffle_t<4>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<nChw16c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<nChw8c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<ncdhw>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<nchw>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<ndhwc>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<nhwc>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<4>::execute_<any>(const exec_ctx_t &ctx) const;
-
-template void ref_shuffle_t<1>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<nChw16c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<nChw8c>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<ncdhw>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<nchw>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<ndhwc>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<nhwc>(const exec_ctx_t &ctx) const;
-template void ref_shuffle_t<1>::execute_<any>(const exec_ctx_t &ctx) const;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp
deleted file mode 100644
index 5e09a1a69b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp
+++ /dev/null
@@ -1,111 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_SHUFFLE_HPP
-#define CPU_REF_SHUFFLE_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_shuffle_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template<int data_type_size>
-struct ref_shuffle_t : public cpu_primitive_t {
- using shuffle_class = ref_shuffle_t<data_type_size>;
-
- struct pd_t: public cpu_shuffle_pd_t {
- using cpu_shuffle_pd_t::cpu_shuffle_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", shuffle_class);
-
- status_t init() {
- using namespace format_tag;
-
- bool ok = true
- && data_type_size
- == types::data_type_size(data_md()->data_type);
- if (!ok) return status::unimplemented;
-
- if (ndims() == 5) {
- dat_tag_ = memory_desc_matches_one_of_tag(
- *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc);
- } else if (ndims() == 4) {
- dat_tag_ = memory_desc_matches_one_of_tag(
- *data_md(), nChw16c, nChw8c, nchw, nhwc);
- } else
- dat_tag_ = any;
-
- return status::success;
- }
-
- format_tag_t dat_tag_;
- };
-
- ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) {
- const int axis_size = pd()->axis_size();
- const int group_size = pd()->group_size();
- const int transpose_row = pd()->is_fwd() ? group_size
- : axis_size / group_size;
- const int transpose_col = pd()->is_fwd() ? axis_size / group_size
- : group_size;
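-        // Channel shuffle is a transpose of the axis viewed as a
-        // (transpose_row x transpose_col) matrix; precompute the inverse
-        // channel index map once so execute_() can use a plain gather.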
- rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64);
- parallel_nd(transpose_col, transpose_row, [&](int i, int j) {
- rev_transposed_[j * transpose_col + i] = i * transpose_row + j;
- });
- }
-
- ~ref_shuffle_t() { free(rev_transposed_); }
-
- typedef typename typesize_traits<data_type_size>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- using namespace format_tag;
- switch (pd()->dat_tag_) {
- case nCdhw16c: execute_<nCdhw16c>(ctx); break;
- case nChw16c: execute_<nChw16c>(ctx); break;
- case nCdhw8c: execute_<nCdhw8c>(ctx); break;
- case nChw8c: execute_<nChw8c>(ctx); break;
- case ncdhw: execute_<ncdhw>(ctx); break;
- case nchw: execute_<nchw>(ctx); break;
- case ndhwc: execute_<ndhwc>(ctx); break;
- case nhwc: execute_<nhwc>(ctx); break;
- default: execute_<any>(ctx); break;
- }
- return status::success;
- }
-
-private:
- template<format_tag_t tag>
- void execute_(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- int *rev_transposed_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp
deleted file mode 100644
index 36d5237f56..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp
+++ /dev/null
@@ -1,264 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <float.h>
-#include <math.h>
-
-#include "c_types_map.hpp"
-#include "mkldnn_thread.hpp"
-#include "type_helpers.hpp"
-
-#include "ref_softmax.hpp"
-#include "gemm/os_blas.hpp"
-
-#ifdef USE_MKL
-#include "mkl_vml_functions.h"
-#endif
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::execute_forward_dense(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- parallel_nd(outer_size_, [&](int ou) {
- const data_t *src_data = src + ou * channels_;
- data_t *dst_data = dst + ou * channels_;
- data_t scalar = 0;
-
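-        // Numerically stable softmax along the channel axis:
-        // subtract the max, exponentiate, then normalize by the sum.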
- _max(channels_, src_data, &scalar);
- _sub(channels_, scalar, src_data, dst_data);
- _exp(channels_, dst_data, dst_data);
- _sum(channels_, dst_data, &scalar);
- _scal(channels_, data_t(1)/scalar, dst_data);
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::execute_forward_generic(
- const exec_ctx_t &ctx) const {
- auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- data_t space_max_val = 0, space_denom_val = 0;
- data_t *space_max = &space_max_val, *space_denom = &space_denom_val;
- if (inner_size_ > 1) {
- using namespace memory_tracking::names;
- space_max = scratchpad(ctx).template get<data_t>(key_softmax_reduction);
- space_denom = space_max + inner_size_;
- }
-
- const memory_desc_wrapper data_d(pd()->src_md());
- const size_t dim = channels_ * inner_size_;
-
- for (int ou = 0; ou < outer_size_; ou++) {
- utils::array_set(space_max, -FLT_MAX, inner_size_);
- utils::array_set(space_denom, 0, inner_size_);
-
- for (int c = 0; c < channels_; c++) {
- for(int in = 0; in < inner_size_; in++) {
- size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
- space_max[in] = nstl::max(space_max[in], src[off]);
- }
- }
-
- for (int c = 0; c < channels_; c++) {
- for(int in = 0; in < inner_size_; in++) {
- size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
- space_denom[in] += dst[off] = exp(src[off] - space_max[in]);
- }
- }
-
- for (int c = 0; c < channels_; c++) {
- for (int in = 0; in < inner_size_; in++) {
- size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
- dst[off] /= space_denom[in];
- }
- }
- }
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::_max(int n, const data_t *x,
- data_t *max_data) const {
-// The Intel(R) C++ Compiler generates the faster maxps + shuffle pattern
-// for the max search on its own.
-#if !defined(__INTEL_COMPILER)
-    // The code below makes the compiler generate the maxps instruction
-    // rather than maxss, which is generated for the 'else' code path.
- auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); };
- auto min_wrapper = [](int a, int b) { return nstl::min(a, b); };
-
- constexpr int unroll_factor = 32;
- data_t max_values[unroll_factor];
-
- if (n < unroll_factor) {
- data_t max_val = x[0];
- for (int i = 1; i < n; i++) {
- max_val = max_wrapper(max_val, x[i]);
- }
- max_data[0] = max_val;
- return;
- }
- for (int i = 0; i < unroll_factor; i++) {
- max_values[i] = x[i];
- }
- for (int i = unroll_factor; i < n; i += unroll_factor) {
- int offset = min_wrapper(i, n - unroll_factor);
- for (int j = 0; j < unroll_factor; j++) {
- max_values[j] = max_wrapper(max_values[j], x[offset + j]);
- }
- }
- data_t max_val = max_values[0];
- for (int i = 1; i < unroll_factor; i++) {
- max_val = max_wrapper(max_val, max_values[i]);
- }
- max_data[0] = max_val;
-#else
- max_data[0] = x[0];
- for (int c = 1; c < n; ++c)
- max_data[0] = nstl::max(max_data[0], x[c]);
-#endif
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::_sub(int n, data_t alpha, const data_t *x,
- data_t *y) const {
- constexpr int unroll_factor = 32;
- int tail = n % unroll_factor;
- for (int i = 0; i < n - tail; i += unroll_factor) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < unroll_factor; j++) {
- y[i + j] = x[i + j] - alpha;
- }
- }
- PRAGMA_OMP_SIMD()
- for (int i = n - tail; i < n; i++) {
- y[i] = x[i] - alpha;
- }
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::_exp(int n, const data_t *a,
- data_t *r) const {
-#ifdef USE_MKL
- if (data_type == data_type::f32) {
- vsExp(n, a, r);
- return;
- }
-#endif
- parallel_nd(n, [&](int c) { r[c] = expf(a[c]); });
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::_sum(int n, const data_t *x,
- data_t *sum_data) const {
-#ifdef USE_CBLAS
-    // Here we are summing values of x (e.g. e^z), which are positive,
-    // so we can use BLAS ASUM.
- if (data_type == data_type::f32) {
- sum_data[0] = cblas_sasum(n, x, 1);
- return;
- }
-#endif
- data_t tsum = static_cast<data_t>(0);
- PRAGMA_OMP_SIMD(reduction(+ : tsum))
- for (int c = 0; c < n; ++c)
- tsum += x[c];
- sum_data[0] = tsum;
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_fwd_t<data_type>::_scal(int n, data_t alpha, data_t *x) const {
-#ifdef USE_CBLAS
- if (data_type == data_type::f32) {
- cblas_sscal(n, alpha, x, 1);
- return;
- }
-#endif
- parallel_nd(n, [&](int c) { x[c] *= alpha; });
-}
-
-template struct ref_softmax_fwd_t<data_type::f32>;
-
-
-// NC/NCHW softmax along the final axis (1 for NC, 3 for NCHW)
-template <impl::data_type_t data_type>
-void ref_softmax_bwd_t<data_type>::execute_backward_dense(
- const exec_ctx_t &ctx) const {
- auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
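-    // Dense softmax backward: diff_src = dst * (diff_dst - sum(diff_dst * dst)),
-    // with the sum taken over the channel axis.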
- parallel_nd(outer_size_, [&](int ou) {
- data_t sbr = 0;
- size_t off = channels_*ou;
- for (int c = 0; c < channels_; c++) {
- size_t loff = off + c;
- data_t ldata = dst[loff];
- sbr += diff_dst[loff]*ldata;
- diff_src[loff] = ldata;
- }
-
- for(int c=0; c < channels_ ; ++c) {
- size_t loff = off + c;
- diff_src[loff] *= (diff_dst[loff] - sbr);
- }
- });
-}
-
-template <impl::data_type_t data_type>
-void ref_softmax_bwd_t<data_type>::execute_backward_generic(
- const exec_ctx_t &ctx) const {
- auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST);
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
-
- const memory_desc_wrapper diff_d(pd()->diff_src_md());
- const memory_desc_wrapper data_d(pd()->dst_md());
-
- const size_t dim = channels_ * inner_size_;
-
- parallel_nd(outer_size_, [&](int ou) {
- for (int in = 0; in < inner_size_; in++) {
- data_t sbr = 0;
- for (int c = 0; c < channels_; c++) {
- size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in);
-                size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in);
- sbr += diff_dst[off_diff] * dst[off_data];
- }
-
- for(int c=0; c < channels_ ; ++c) {
- size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in);
- size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in);
- diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr);
- }
- }
- });
-}
-
-template struct ref_softmax_bwd_t<data_type::f32>;
-
-}
-}
-}
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp
deleted file mode 100644
index 5cb74d8007..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp
+++ /dev/null
@@ -1,186 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_SOFTMAX_HPP
-#define CPU_REF_SOFTMAX_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "cpu_softmax_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <impl::data_type_t data_type>
-struct ref_softmax_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_softmax_fwd_pd_t {
- using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t);
-
- status_t init() {
- bool ok = true
- && is_fwd()
- && src_md()->data_type == data_type
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- const int inner_size = utils::array_product(
- desc()->data_desc.dims + desc()->softmax_axis + 1,
- desc()->data_desc.ndims - desc()->softmax_axis - 1);
-
- if (inner_size > 1) {
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(memory_tracking::names::key_softmax_reduction,
- sizeof(data_t) * 2 * inner_size);
- }
- }
- };
-
- ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- {
- auto ndims = pd()->desc()->data_desc.ndims;
- auto dims = pd()->desc()->data_desc.dims;
- auto axis = pd()->desc()->softmax_axis;
-
- outer_size_ = utils::array_product(dims, axis);
- channels_ = dims[axis];
- inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
-
- const memory_desc_wrapper data_d(pd()->src_md());
-
- bool no_axis_blocking = true;
- for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk)
- if (data_d.blocking_desc().inner_idxs[iblk] == axis)
- no_axis_blocking = false;
-
- use_dense_ = inner_size_ == 1 && data_d.is_dense()
- && no_axis_blocking
- && data_d.blocking_desc().strides[axis] == 1;
- }
-
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (use_dense_)
- execute_forward_dense(ctx);
- else
- execute_forward_generic(ctx);
- return status::success;
- }
-
-private:
- void execute_forward_dense(const exec_ctx_t &ctx) const;
- void execute_forward_generic(const exec_ctx_t &ctx) const;
-
- void _max(int n, const data_t *x, data_t *max_data) const;
- void _sub(int n, data_t alpha, const data_t *x, data_t *y) const;
- void _exp(int n, const data_t *a, data_t *r) const;
- void _sum(int n, const data_t *x, data_t *sum_data) const;
- void _scal(int n, data_t alpha, data_t *x) const;
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- bool use_dense_;
- int outer_size_, channels_, inner_size_;
-};
-
-template <impl::data_type_t data_type>
-struct ref_softmax_bwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_softmax_bwd_pd_t {
- using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t);
-
- status_t init() {
- bool ok = true
- && !is_fwd()
- && utils::everyone_is(data_type,
- dst_md()->data_type,
- diff_src_md()->data_type)
- && attr()->has_default_values();
- if (!ok) return status::unimplemented;
-
- return status::success;
- }
- };
-
- ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {
- auto dims = pd()->desc()->diff_desc.dims;
- auto axis = pd()->desc()->softmax_axis;
- auto ndims = pd()->desc()->diff_desc.ndims;
-
- outer_size_ = utils::array_product(dims, axis);
- channels_ = dims[axis];
- inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
-
- const memory_desc_wrapper data_d(pd()->dst_md());
- const memory_desc_wrapper diff_d(pd()->diff_dst_md());
-
- bool no_axis_blocking = true;
- for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk)
- if (diff_d.blocking_desc().inner_idxs[iblk] == axis)
- no_axis_blocking = false;
-
- use_dense_ = true
- && inner_size_ == 1
- && diff_d == data_d
- && diff_d.is_dense()
- && no_axis_blocking
- && diff_d.blocking_desc().strides[axis] == 1;
- }
-
- typedef typename prec_traits<data_type>::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- if (use_dense_)
- execute_backward_dense(ctx);
- else
- execute_backward_generic(ctx);
- return status::success;
- }
-
-private:
- void execute_backward_dense(const exec_ctx_t &ctx) const;
- void execute_backward_generic(const exec_ctx_t &ctx) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- bool use_dense_;
- int outer_size_, channels_, inner_size_;
-};
-
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp
deleted file mode 100644
index 3b2a75d99b..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp
+++ /dev/null
@@ -1,101 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef REF_SUM_HPP
-#define REF_SUM_HPP
-
-#include "reorder_pd.hpp"
-
-#include "cpu_sum_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct ref_sum_t: public cpu_primitive_t {
- struct pd_t: public cpu_sum_pd_t {
- using cpu_sum_pd_t::cpu_sum_pd_t;
-
- pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
- for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
- reorder_pds_.push_back(
- (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
- }
-
- ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
-
- DECLARE_SUM_PD_T("ref:any", ref_sum_t);
-
- status_t init() {
- bool ok = cpu_sum_pd_t::init() == status::success;
- if (!ok) return status::unimplemented;
-
- for (int i = 0; i < n_; ++i) {
- auto r_impls = engine_->get_reorder_implementation_list();
- for (auto r = r_impls; *r; ++r) {
- primitive_attr_t attr;
- attr.output_scales_.set(scales_[i]);
- if (i != 0) attr.post_ops_.append_sum(1.0);
-
- reorder_pd_t *r_pd;
- if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
- engine_, dst_md()) == status::success) {
- r_pd->init_info();
- reorder_pds_.push_back(r_pd);
- break;
- }
- }
- }
-
- ok = reorder_pds_.size() == (size_t)n_;
- return ok ? status::success : status::unimplemented;
- }
-
- nstl::vector<const reorder_pd_t *> reorder_pds_;
- };
-
- ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) {
- const int n = pd()->n_inputs();
- reorders_.resize(n);
- for (int i = 0; i < n; ++i)
- pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
- }
-
- ~ref_sum_t() { for (auto &r: reorders_) delete r; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto n = pd()->n_inputs();
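-        // Each input is reordered into dst; for i > 0 the reorder carries a
-        // sum post-op (set up in pd_t::init), so it accumulates into dst
-        // instead of overwriting it.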
- for (int i = 0; i < n; ++i) {
- exec_args_t r_args;
- r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
- r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
- exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
- reorders_[i]->execute(r_ctx);
- }
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- nstl::vector<primitive_t *> reorders_;
-};
-
-}
-}
-}
-
-#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
deleted file mode 100644
index 537084db91..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Common for RNN and LSTM cell execution
- */
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-using namespace rnn_utils;
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_cell_execution_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) {
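-    // ws_gates = W_layer * x_t (layer GEMM; done outside the cell when
-    // merge_gemm_layer is set) + W_iter * h_(t-1) (iteration GEMM), followed
-    // by the post-GEMM / elementwise gate computation.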
- if (!rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
- rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
- states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
- rnn.gates_ws_ld);
- }
- (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
- 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
- rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
-
- if (rnn_postgemm_ != nullptr)
- rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_,
- states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
- diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
- ws_cell_);
- else
- (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
- states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
- diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
- ws_cell_);
-}
-template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
-template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
-
-template <>
-rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) {
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
- (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
- states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
- diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
- ws_cell_);
-
- /// bwd by data on the cell
- (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
- 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld,
- 0.0, diff_states_t_l_, rnn.states_ws_ld);
-
- if (!rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
- rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
- rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
- &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
-
- /// bwd by weights on the cell
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
- rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
- diff_w_layer_, rnn.diff_weights_layer_ld);
- }
-
- if (!rnn.merge_gemm_iter)
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
- rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0,
- diff_w_iter_, rnn.diff_weights_iter_ld);
-
- /// bwd by bias we just accumulate diffs from the gates
- gates_reduction(rnn, ws_gates_, diff_bias_);
-}
-
-}
-}
-}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
deleted file mode 100644
index e1a61d4c62..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
+++ /dev/null
@@ -1,180 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Cell execution GRU
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace rnn_utils;
-
-#define AOC array_offset_calculator
-template <>
-rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_[0]);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
-
- // 1. gemm Wx[0-2],x
- if (!rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
- rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
- states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
- rnn.gates_ws_ld);
- }
-
- // 2. gemm Wh[0-1],h
- (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb,
- rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
- rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
-
- // 3. activation zt and rt + elemwise multiplication rt,ht-1
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
- ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
- states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
- }
- });
-
- // 4. gemm Wh[2],h~t
- (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1],
- rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0,
- &(ws_gates(0, 2, 0)), rnn.gates_ws_ld);
-
- // 5. activation h~t + calculate ht
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
- states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
- + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
- }
- });
-}
-
-template <>
-rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) {
- assert(!"GRU int8 is not supported");
-}
-
-template <>
-rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
- ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
- ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
- ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
-
- // use state memory for intermediate computations
- // TODO: use cell ws for that
- float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0));
- float *hG1_ = dhG1_;
- AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld);
- AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld);
-
- // 1. calculate dG2, dG1, and part of dht-1
- // dG2^ = dh * (1 - G0) * (1 - G2^2)
- // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
- // dht-1 (part) = dh * G0
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float h = states_tm1_l(i, j);
- float dHt = diff_states_tp1_l(0, i, j)
- + diff_states_t_lp1(rnn.n_states, i, j);
- float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
- * one_m_square(ws_gates(i, 2, j));
- float dG0 = (h - ws_gates(i, 2, j)) * dHt
- * x_m_square(ws_gates(i, 0, j));
-
- diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
- ws_gates(i, 0, j) = dG0;
- ws_gates(i, 2, j) = dG2;
- }
- });
-
- // 2. calculate intermediate d(hG1)
- // d(hG1) = dG2 * W2h^t
- (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1],
- rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0,
- dhG1_, rnn.states_ws_ld);
-
- // 3. calculate dG1^ and part of dht-1
- // dG1^ = d(hG1) * h * G1 * (1 - G1)
- // dht-1 (part) += d(hG1) * G1
- // h * G1 (required for dWh)
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float h = states_tm1_l(i, j);
- float G1 = ws_gates(i, 1, j);
- diff_states_t_l(0, i, j) += dhG1(i, j) * G1;
- ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1);
- hG1(i, j) = G1 * h;
- }
- });
-
- // 4. calculate diff weights
- // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
- gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
- rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
- rnn.diff_weights_iter_ld);
- gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)),
- rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0,
- &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld);
-
- // 5. calculate diff states
- // dht-1 += dG1 * W1h + dG0 * W0h
- (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
- (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0],
- rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0,
- diff_states_t_l_, rnn.states_ws_ld);
-
- if (!rnn.merge_gemm_layer) {
- // dWx += [dG0 dG1 dG2] * [x]
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
- rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
- diff_w_layer_, rnn.diff_weights_layer_ld);
- // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
- (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
- rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
- rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
- &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld);
- }
-
- // 6. calculate diff bias
- gates_reduction(rnn, ws_gates_, diff_bias_);
-}
-#undef AOC
-
-}
-}
-}
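The forward GRU above computes gates 0 and 1 (update z and reset r) from two GEMMs, temporarily stores r * h_{t-1} in the state buffer so that a second iteration GEMM can produce the candidate pre-activation, and then blends: h_t = z * h_{t-1} + (1 - z) * tanh(candidate). A minimal per-element restatement of that final blend, assuming the GEMMs and bias additions have already produced the pre-activations g0 and g2 (hypothetical helper names, not the mkl-dnn kernel):

    #include <cmath>

    static inline float logistic(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    static inline float gru_blend(float g0, float g2, float h_prev) {
        float z = logistic(g0);       // update gate (gate 0)
        float c = std::tanh(g2);      // candidate; the reset gate is already
                                      // folded into g2 via the r*h_{t-1} GEMM
        return h_prev * z + (1.0f - z) * c;
    }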
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
deleted file mode 100644
index 8dea8c90a4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
+++ /dev/null
@@ -1,170 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Cell execution GRU with linear before reset
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace rnn_utils;
-#define AOC array_offset_calculator
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
- ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_);
- AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
-
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j);
- ws_gates(i, 0, j) = logistic_fwd(
- ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j));
- ws_gates(i, 1, j) = logistic_fwd(
- ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j));
- ws_gates(i, 2, j) = tanh_fwd(
- ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j));
- states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
- + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
- if (rnn.is_training)
- ws_Wh_b(i, j) = Wh_b;
- }
- });
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) {
- assert(!"GRU LBR int8 is not supported");
-}
-
-template <>
-rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) {
- if (!rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
- rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
- states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
- rnn.gates_ws_ld);
- }
- (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
- 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
- rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld);
- (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
- states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
- diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
- ws_cell_);
-}
-
-template <>
-rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) {
- assert(!"GRU LBR int8 is not supported");
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
- ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
- ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
- ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
- AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
-
- // 1. calculate dG0 dG1 dG2
- // dG0 = (ht-1 - G2) * dht * (1 - G0) * G0
- // dG1 = (W*h + b) * dG2 * (1 - G1) * G1
- // dG2 = (1 - G0) * dht * (1 - G2*G2)
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float h = states_tm1_l(i, j);
- float dHt = diff_states_tp1_l(0, i, j)
- + diff_states_t_lp1(rnn.n_states, i, j);
- float dG0 = (h - ws_gates(i, 2, j)) * dHt
- * x_m_square(ws_gates(i, 0, j));
- float dG2 = (1.0f - ws_gates(i, 0, j))
- * one_m_square(ws_gates(i, 2, j)) * dHt;
- float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));
-
- diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
- ws_gates(i, 2, j) = dG2;
- ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j);
- ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0;
- ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1;
- }
- });
-}
-
-template <>
-rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) {
- ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
-
- (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
- states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
- diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
- ws_cell_);
-
- if (!rnn.merge_gemm_layer) {
- // dx = dG * Wx^t
- (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
- rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
- rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
- &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
- // dWx += dG^t * x
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
- rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
- diff_w_layer_, rnn.diff_weights_layer_ld);
- }
- // dh += dGr * Wh^t
- (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
- 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld,
- 1.0, diff_states_t_l_, rnn.states_ws_ld);
-
- // dWh += dGr^t * h
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_,
- rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
- rnn.diff_weights_layer_ld);
-
- // db1-3 += e * dG
- // db4 += e * (r * dG2)
- gates_reduction(rnn, ws_gates_, diff_bias_);
-
- parallel_nd(rnn.dic, [&](int j) {
- for (int i = 0; i < rnn.mb; i++) {
- diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j);
- }
- });
-}
-
-#undef AOC
-
-}
-}
-}
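In the linear-before-reset variant removed above, the iteration GEMM output (held in ws_cell_) carries U*h_{t-1} for all three gates plus a fourth bias term, and the reset gate scales that recurrent contribution inside the candidate rather than scaling h_{t-1} before the GEMM. A per-element sketch under those assumptions, where wx is the layer GEMM output, uh the iteration GEMM output, and b the four biases (hypothetical names, not the library's elemwise):

    #include <cmath>

    static inline float logistic_(float x) { return 1.0f / (1.0f + std::exp(-x)); }

    static inline float gru_lbr_update(const float wx[3], const float uh[3],
            const float b[4], float h_prev) {
        float wh_b = uh[2] + b[3];                     // "linear before reset" term
        float z = logistic_(wx[0] + uh[0] + b[0]);     // update gate
        float r = logistic_(wx[1] + uh[1] + b[1]);     // reset gate
        float c = std::tanh(wx[2] + r * wh_b + b[2]);  // candidate
        return h_prev * z + (1.0f - z) * c;
    }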
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
deleted file mode 100644
index a15ba00d4c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
+++ /dev/null
@@ -1,143 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Cell execution LSTM
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "../simple_q10n.hpp"
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace rnn_utils;
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
- ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
-
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
- ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
- ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
- ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j));
-
- float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j)
- + ws_gates(i, 0, j) * ws_gates(i, 2, j);
- states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp);
- c_states_t_l(i, j) = tmp;
- }
- });
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) {
- ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_u8_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
- ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
-
- float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
- auto q_d = [&](float f) {
- float qf = f * data_scale + data_shift;
- return qz_a1b0<float, src_data_t>()(qf);
- };
-
- auto deq_w = [&](acc_data_t s, int gate, int j) {
- return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ?
- saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) :
- saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j]
- * data_scale));
- };
-
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float G0 = logistic_fwd<float>(
- deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j));
- float G1 = logistic_fwd<float>(
- deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j));
- float G2 = tanh_fwd<float>(
- deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j));
- float G3 = logistic_fwd<float>(
- deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j));
- float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2;
- states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp));
- c_states_t_l(i, j) = tmp;
- }
- });
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
- ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
- ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
- ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
-
- parallel_nd(rnn.mb, [&](int i) {
- PRAGMA_OMP_SIMD()
- for (int j = 0; j < rnn.dic; j++) {
- float Ct = c_states_t_l(i, j);
- /// @todo save it in the workspace in fwd pass or recompute it to
- /// save bw
- float tanhCt = tanh_fwd(Ct);
- // we have 2 incoming diffs on Ht
- float dHt = diff_states_tp1_l(0, i, j)
- + diff_states_t_lp1(rnn.n_states, i, j);
- float dCt = diff_states_tp1_l(1, i, j)
- + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;
-
- float dG1 = c_states_tm1_l(i, j) * dCt
- * x_m_square(ws_gates(i, 1, j));
- float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
- float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));
- float dG2
- = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j));
-
- diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j);
-
- ws_gates(i, 0, j) = dG0;
- ws_gates(i, 1, j) = dG1;
- ws_gates(i, 2, j) = dG2;
- ws_gates(i, 3, j) = dG3;
- }
- });
-}
-
-}
-}
-}
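Besides the f32 path, the removed LSTM elemwise has a u8s8 path whose arithmetic is easy to lose in the lambdas: the s32 GEMM accumulators are rescaled by the inverse of weights_scale * data_scale (per gate and channel when the scale mask is non-zero), and the produced h state is re-quantized with the data scale and shift before being stored as u8. A scalar restatement of just those two conversions, with the saturation written out explicitly and the attribute lookups omitted (assumed helper names, not the library API):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // s32 GEMM accumulator -> float gate pre-activation
    static inline float deq_gate(int32_t acc, float w_scale, float d_scale) {
        return (float)acc / (w_scale * d_scale);
    }

    // float h state -> u8 state: round(h*scale + shift), clamped to [0, 255]
    static inline uint8_t q_state(float h, float d_scale, float d_shift) {
        float q = std::nearbyint(h * d_scale + d_shift);
        return (uint8_t)std::min(255.0f, std::max(0.0f, q));
    }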
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
deleted file mode 100644
index 4536e8dfad..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
+++ /dev/null
@@ -1,113 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * Cell execution of Vanilla RNN
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::math;
-using namespace rnn_utils;
-
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::forward>(
- float dd, float s, float alpha, float cliping) {
- return relu_fwd<float>(s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_relu, prop_kind::backward>(
- float dd, float s, float alpha, float cliping) {
- return relu_bwd<float>(dd, s, alpha);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
- float dd, float s, float alpha, float cliping) {
- return tanh_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
- float dd, float s, float alpha, float cliping) {
- return dd * one_m_square<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
- float dd, float s, float alpha, float cliping) {
- return logistic_fwd<float>(s);
-}
-
-template <>
-float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
- float dd, float s, float alpha, float cliping) {
- return dd * x_m_square<float>(s);
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
-
- parallel_nd(rnn.mb, [&](int i) {
- for (int j = 0; j < rnn.dic; j++) {
- const float h
- = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
- ws_gates(i, 0, j) = states_t_l(i, j) = h;
- }
- });
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) {
- assert(!"VANILLA RNN int8 is not supported");
-}
-
-template <>
-rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) {
- ws_gates_aoc_t ws_gates(rnn, ws_gates_);
- bias_aoc_t bias(rnn, bias_);
- ws_states_aoc_t states_t_l(rnn, states_t_l_);
- ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
- ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
- ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
- ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
-
- parallel_nd(rnn.mb, [&](int i) {
- for (int j = 0; j < rnn.dic; ++j) {
- const float dH = diff_states_t_lp1(rnn.n_states, i, j)
- + diff_states_tp1_l(0, i, j);
- auto g = ws_gates(i, 0, j);
- ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
- }
- });
-}
-
-}
-}
-}
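The vanilla RNN cell reduces to a single activation per element: the forward elemwise applies the chosen activation to the biased GEMM output, and the backward elemwise multiplies the incoming diff by the activation's derivative evaluated at the saved forward output (one_m_square(s) = 1 - s*s for tanh, x_m_square(s) = s - s*s for the logistic). For the tanh case, that pairing looks like the following sketch:

    #include <cmath>

    static inline float rnn_tanh_fwd(float preact) { return std::tanh(preact); }

    // d_h is the summed incoming diff, fwd_out the saved forward output
    static inline float rnn_tanh_bwd(float d_h, float fwd_out) {
        return d_h * (1.0f - fwd_out * fwd_out);   // d_h * one_m_square(s)
    }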
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
deleted file mode 100644
index b39427caf9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
+++ /dev/null
@@ -1,191 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_RNN_PD_HPP
-#define CPU_RNN_PD_HPP
-
-#include "c_types_map.hpp"
-#include "nstl.hpp"
-#include "rnn_pd.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "rnn_utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
- using rnn_fwd_pd_t::rnn_fwd_pd_t;
-
-protected:
- status_t set_default_params() {
- using namespace format_tag;
- if (src_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
- if (dst_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
-
- // Optional parameters
- if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
- if (with_bias() && bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
- if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
-
- return status::success;
- }
-
- status_t check_layout_consistency() {
- using namespace format_tag;
- using namespace data_type;
- using namespace types;
-
- auto is_blocked = [&](memory_desc_t md, int ndims) {
- return md.format_kind == format_kind::blocked && md.ndims == ndims;
- };
-
- bool ok = true;
- ok = ok && is_blocked(src_layer_md_, 3)
- && is_blocked(dst_layer_md_, 3);
- ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
- is_blocked(src_iter_md_, 5))
- && IMPLICATION(!is_zero_md(&dst_iter_md_),
- is_blocked(dst_iter_md_, 5));
-
- if (weights_layer_md_.format_kind == format_kind::rnn_packed)
- ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
- == mkldnn_ldigo_p);
- else
- ok = ok && rnn_utils::is_ldigo(&weights_layer_md_);
-
- if (weights_iter_md_.format_kind == format_kind::rnn_packed)
- ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
- == mkldnn_ldigo_p);
- else
- ok = ok && rnn_utils::is_ldigo(&weights_iter_md_);
-
- ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
- memory_desc_matches_tag(bias_md_, ldgo));
-
- /* Int8 is supported only for packed weights */
- data_type_t weights_iter_dt = weights_iter_md_.data_type;
- data_type_t weights_layer_dt = weights_layer_md_.data_type;
- ok = ok && IMPLICATION(
- weights_iter_dt == s8, weights_iter_md_.format_kind
- == format_kind::rnn_packed);
- ok = ok && IMPLICATION(
- weights_layer_dt == s8, weights_layer_md_.format_kind
- == format_kind::rnn_packed);
-
- return ok ? status::success : status::unimplemented;
- }
-};
-
-struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
- using rnn_bwd_pd_t::rnn_bwd_pd_t;
-
-protected:
- status_t set_default_params() {
- using namespace format_tag;
- if (src_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
- if (dst_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
-
- if (diff_src_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc));
- if (diff_weights_layer_md_.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo));
- CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo));
- }
- if (diff_weights_iter_md_.format_kind == format_kind::any) {
- CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo));
- CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo));
- }
- if (diff_dst_layer_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc));
-
- // Optional parameters
- if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
- if (with_bias() && bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
- if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
-
- if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc));
- if (with_bias() && diff_bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo));
- if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc));
-
- return status::success;
- }
-
- status_t check_layout_consistency() {
- using namespace format_tag;
- using namespace types;
-
- auto is_blocked = [&](memory_desc_t md, int ndims) {
- return md.format_kind == format_kind::blocked && md.ndims == ndims;
- };
-
- bool ok = true;
- ok = ok && is_blocked(src_layer_md_, 3)
- && is_blocked(dst_layer_md_, 3);
- ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
- is_blocked(src_iter_md_, 5))
- && IMPLICATION(!is_zero_md(&dst_iter_md_),
- is_blocked(dst_iter_md_, 5));
-
- if (weights_layer_md_.format_kind == format_kind::rnn_packed)
- ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
- == mkldnn_ldgoi_p);
- else
- ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_);
-
- if (weights_iter_md_.format_kind == format_kind::rnn_packed)
- ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
- == mkldnn_ldgoi_p);
- else
- ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_);
-
- ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
- memory_desc_matches_tag(bias_md_, ldgo));
-
- ok = ok && is_blocked(diff_src_layer_md_, 3)
- && is_blocked(diff_dst_layer_md_, 3);
- ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_),
- is_blocked(diff_src_iter_md_, 5))
- && IMPLICATION(!is_zero_md(&diff_dst_iter_md_),
- is_blocked(diff_dst_iter_md_, 5));
-
- ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_)
- && rnn_utils::is_ldigo(&diff_weights_iter_md_);
- ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_),
- memory_desc_matches_tag(diff_bias_md_, ldgo));
-
- return ok ? status::success : status::unimplemented;
- }
-};
-}
-}
-}
-
-#endif
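The layout checks above lean on one idiom: IMPLICATION(cond, req), i.e. an optional memory descriptor only constrains the layout when it is actually present, and every check is folded into a single boolean so one unimplemented status can be returned at the end. As a stand-alone restatement of that idiom (not mkl-dnn's utils.hpp):

    static inline bool implication(bool cond, bool requirement) {
        return !cond || requirement;
    }
    // e.g. ok = ok && implication(has_bias, bias_matches_ldgo);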
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
deleted file mode 100644
index 09445648aa..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
+++ /dev/null
@@ -1,401 +0,0 @@
-/*******************************************************************************
-* Copyright 2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- * JIT post-GEMM (elementwise) kernel for LSTM cell execution
- */
-
-#include "rnn_utils.hpp"
-#include "../jit_generator.hpp"
-#include "../jit_uni_eltwise.hpp"
-#include "c_types_map.hpp"
-#include "utils.hpp"
-
-#include "mkldnn_thread.hpp"
-
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-struct jit_uni_rnn_postgemm_kernel : public jit_generator {
-
- typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_,
- void *c_states_t_l_, void *c_states_tm1_l_);
-
- jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){}
-
- virtual void init() = 0;
-
-template <typename src_data_t, typename acc_data_t>
- rnn_elemwise_sig(execute) {
- rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_);
- rnn_utils::bias_aoc_t bias(rnn, bias_);
- rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_);
- rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
- rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
-
- // Todo: add parallelization on dic for the batch 1 case
- // Assumption: the kernel runs a loop on dic elements
- parallel_nd(rnn.mb, [&](int i) {
- auto b_ = &bias(0, 0);
- auto g_ = &ws_gates(i, 0, 0);
- auto s_tl_ = &states_t_l(i, 0);
- auto c_tl_ = &c_states_t_l(i, 0);
- auto c_tm1l_ = &c_states_tm1_l(i, 0);
- kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_);
- });
- }
-
-protected:
- kernel_t kernel_;
- const rnn_utils::rnn_conf_t &rnn_;
- const primitive_attr_t *attr_;
-};
-
-template <cpu_isa_t isa, impl::data_type_t src_data_t>
-struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel
-{
- DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd)
-
- typedef typename utils::conditional<src_data_t == data_type::u8, int32_t,
- float>::type acc_data_t;
- typedef typename utils::conditional<isa == avx512_core,
- jit_uni_eltwise_injector_f32<avx512_common>,
- jit_uni_eltwise_injector_f32<isa>>::type injector_t;
-
- jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr)
- : jit_uni_rnn_postgemm_kernel(rnn, attr){}
-
- void init() override {
- // we use rax for both constant tables as they use the same table
- sigmoid_injector_ = new injector_t(this,
- alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax);
- tanh_injector_ = new injector_t(this,
- alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax);
- generate();
- kernel_ = (kernel_t) this->getCode();
- }
-
-protected:
- injector_t *sigmoid_injector_;
- injector_t *tanh_injector_;
-
- // register size in bytes
- using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
- size_t vlen = cpu_isa_traits<isa>::vlen;
- size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen;
- size_t cstate_dt_size = sizeof(float);
- size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float);
- size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float);
- size_t qscale_dt_size = sizeof(float);
- size_t bias_dt_size = sizeof(float);
-
- void generate() {
- using namespace Xbyak;
-
- int mask = attr_->rnn_weights_qparams_.mask_;
- float *weights_scales = attr_->rnn_weights_qparams_.scales_;
- float data_scale = attr_->rnn_data_qparams_.scale_;
- float data_shift = attr_->rnn_data_qparams_.shift_;
-
- // Labels declaration
- Label vector_loop_start_label, vector_loop_end_label;
- Label rem_loop_start_label, rem_loop_end_label;
- Label table_label;
-
- // Register map
- Reg64 loop_cnt(r11); // loop counter
- Reg64 table_reg(rbx); // table is used for data scale and shifts
- Reg64 weights_scales_reg(r13);
- // We skip vmm0 as it can be used by the injector for masks on sse4.2
- Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7);
-
- // constant table map
- Address dscale_off_addr = ptr[table_reg];
- Address dshift_off_addr = ptr[table_reg + vlen];
- Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen];
- Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits<avx>::vlen];
-
- // quantize from float to u8
- auto q_d = [&](Vmm f, Vmm tmp_vmm) {
- uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm);
- uni_vmulps(f, f, dscale_off_addr); // apply scale
- uni_vaddps(f, f, dshift_off_addr); // apply shift
- uni_vcvtps2dq(f, f); // convert to int32
- uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16
- uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation
- // Note that the results are interleaved by 128 bit chunks, so we need to merge them together
- switch (vlen) {
- case 64: { //avx512
- Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx());
- uni_vmovups(tmpz, zmm_perm_mask_addr);
- vpermd(fz, tmpz, fz);
- break; }
- case 32: { //avx
- Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx());
- uni_vmovups(tmpy, ymm_perm_mask_addr);
- vpermd(fy, tmpy, fy);
- break; }
- case 16: // sse: nothing to do
- break;
- default: assert(!"Unsupported case");
- };
- };
-
- auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) {
- if (packed)
- uni_vrcpps(tmp, s);
- else
- uni_vrcpss(tmp, s); // prevent divide by zero
- // we add one Newton iteration
- uni_vmulps(s, s, tmp);
- uni_vmulps(s, s, tmp); // s <- s * tmp^2
- uni_vaddps(tmp, tmp, tmp);
- uni_vsubps(tmp, tmp, s);
- uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
- };
-
- // dequantize from s32 to float
- auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) {
- // TODO: if mask is 0 precompute mul and inverse
- if (mask == 0)
- uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
- else
- uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]);
- uni_vcvtdq2ps(s, s);
- uni_vmulps(tmp1, tmp1, dscale_off_addr);
- fast_recip(tmp1, tmp2, packed);
- uni_vmulps(s, s, tmp1);
- };
-
- // We start code generations here
- preamble();
-
- // extract addresses passed as parameter
-#ifdef _WIN32
- auto addr_ws_gates_reg = abi_param1;
- auto addr_bias_reg = abi_param2;
- auto addr_states_t_l_reg = abi_param3;
- auto addr_c_states_tm1_l_reg = abi_param4;
- auto addr_c_states_t_l_reg = r10;
- // Here we cannot use rbp to have initial stack pointer so we
- // use rsp and offset it with the size of pushed registers in
- // preamble
- mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]);
-#else
- auto addr_ws_gates_reg = abi_param1;
- auto addr_bias_reg = abi_param2;
- auto addr_states_t_l_reg = abi_param3;
- auto addr_c_states_tm1_l_reg = abi_param4;
- auto addr_c_states_t_l_reg = abi_param5;
-#endif
-
- // initialize registers with addresses and constants
- mov(table_reg, table_label);
- mov(weights_scales_reg, size_t(weights_scales));
- // both sigmoid and tanh use the same table so load address just once in rax
- sigmoid_injector_->load_table_addr();
-
- mov(loop_cnt, rnn_.dic * gate_dt_size);
- cmp(loop_cnt, vlen);
- jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
-
- L(vector_loop_start_label);
- {
- // load G0 G1 G2 G3
- uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
- uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
- uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
- uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
-
- // dequantize the gates from s32 to f32 if needed
- if (src_data_t == data_type::u8){
- deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true);
- deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true);
- deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true);
- deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true);
- }
-
- // add biases
- uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
-
- // inject eltwise code
- sigmoid_injector_->compute_vector(G0.getIdx());
- sigmoid_injector_->compute_vector(G1.getIdx());
- tanh_injector_->compute_vector(G2.getIdx());
- sigmoid_injector_->compute_vector(G3.getIdx());
-
- // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
- uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]);
- uni_vmulps(tmp1_vmm, tmp1_vmm, G1);
- uni_vfmadd231ps(tmp1_vmm, G0, G2);
- uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm);
-
- // states_t_l = G3 * tanh(c_states_t_l)
- tanh_injector_->compute_vector(tmp1_vmm.getIdx());
- uni_vmulps(tmp1_vmm, tmp1_vmm, G3);
-
- // if int8, we quantize the resulting state
- if (src_data_t == data_type::u8)
- q_d(tmp1_vmm, tmp2_vmm);
-
- // write back the result
- if(vlen_dst == vlen)
- uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm);
- else
- // we write only 1/4 of the register
- switch(vlen_dst){
- case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
- case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
- case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
- default:
- assert(!"Unsupported vector length for quantization");
- }
-
- // increment address pointers
- add(addr_ws_gates_reg, vlen);
- add(addr_bias_reg, vlen);
- add(addr_states_t_l_reg, vlen_dst);
- add(addr_c_states_tm1_l_reg, vlen);
- add(addr_c_states_t_l_reg, vlen);
- if (mask != 0)
- add(weights_scales_reg, vlen);
-
- // increment loop counter
- sub(loop_cnt, vlen);
- cmp(loop_cnt, vlen);
- jge(vector_loop_start_label);
- }
- L(vector_loop_end_label);
-
- cmp(loop_cnt, 0);
- je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
- // Same code as above, we just use scalar movss for accessing inputs
- // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
- L(rem_loop_start_label);
- {
- // remapping registers to Xmms
- Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx());
- Xmm tmp1s_vmm(tmp1_vmm.getIdx());
-
- // load G0 G1 G2 G3
- uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
- uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
- uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
- uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
-
- // dequantize the gates from s32 to f32 if needed
- if (src_data_t == data_type::u8){
- deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false);
- deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false);
- deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false);
- deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false);
- }
-
- // add biases
- uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G0s, G0s, tmp1s_vmm);
- uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G1s, G1s, tmp1s_vmm);
- uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G2s, G2s, tmp1s_vmm);
- uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
- uni_vaddps(G3s, G3s, tmp1s_vmm);
-
- // inject eltwise code
- sigmoid_injector_->compute_vector(G0s.getIdx());
- sigmoid_injector_->compute_vector(G1s.getIdx());
- tanh_injector_->compute_vector(G2s.getIdx());
- sigmoid_injector_->compute_vector(G3s.getIdx());
-
- // compute c_states_t_l = G1 * c_tm1_l + G0s * G2
- uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]);
- uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s);
- uni_vfmadd231ps(tmp1s_vmm, G0s, G2s);
- uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm);
-
- // states_t_l = G3 * tanh(c_states_t_l)
- tanh_injector_->compute_vector(tmp1s_vmm.getIdx());
- uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s);
-
- // if int8, we quantize the resulting state
- if (src_data_t == data_type::u8)
- q_d(tmp1_vmm, tmp2_vmm);
-
- // write back the result
- if(vlen_dst == vlen)
- uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm);
- else
- // we write only 1/4 of the register
- switch(vlen_dst){
- case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
- case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
- case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
- default:
- assert(!"Unsupported vector length for quantization");
- }
-
- // increment address pointers
- add(addr_ws_gates_reg, gate_dt_size);
- add(addr_bias_reg, bias_dt_size);
- add(addr_states_t_l_reg, hstate_dt_size);
- add(addr_c_states_tm1_l_reg, cstate_dt_size);
- add(addr_c_states_t_l_reg, cstate_dt_size);
- if (mask != 0)
- add(weights_scales_reg, qscale_dt_size);
-
- // increment loop counter
- sub(loop_cnt, gate_dt_size);
- cmp(loop_cnt, 0);
- jg(rem_loop_start_label);
-
- }
- L(rem_loop_end_label);
-
- postamble();
-
- // Again, only one table is needed and shared between sigmoid and tanh
- sigmoid_injector_->prepare_table(false);
- tanh_injector_->prepare_table(true);
-
- L(table_label);
- {
- for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale));
- for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift));
- // perm mask for ymm
- dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7);
- // perm mask for zmm
- dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7);
- dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14);
- }
- }
-
-};
-
-template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>;
-template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>;
-template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>;
-
-template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>;
-template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>;
-template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>;
-}
-}
-}
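One detail of the removed JIT kernel worth spelling out: the fast_recip helper refines the low-precision rcpps/rcpss estimate with a single Newton-Raphson step, t' = t * (2 - s*t) = 2*t - s*t*t, which is exactly the mul/mul/add/sub sequence it emits, and that reciprocal feeds the 1/(weights_scale * data_scale) factor of the int8 dequantization. A scalar equivalent of the refinement step (illustrative only):

    // s: the value to invert; t: the hardware estimate of 1/s (~12-bit accurate)
    static inline float newton_recip_step(float s, float t) {
        return 2.0f * t - s * t * t;   // one Newton iteration toward 1/s
    }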
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
deleted file mode 100644
index ead536816c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
+++ /dev/null
@@ -1,788 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- General architecture
-
- for diff states, we have n_states + 1 as we have n_states diff
- to propagate to the previous iteration and 1 states to propagate
- to the previous layer
- index 0 is dh for cell(t-1, l) to consume
- index 1 is dc for cell(t-1, l) to consume
- index 2 is dh for cell(t, l-1) to consume
- this indexing enables to have the same indexing for states in elemwise
- function
- only the cell execution function should be impacted
-
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-#include "../gemm/gemm.hpp"
-#include "../simple_q10n.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace rnn_utils;
-#define AOC array_offset_calculator
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
- const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
- float *diff_bias_) const {
- auto body = [&](int i, int k) {
- for (int j = 0; j < rnn.mb; j++)
- diff_bias_[i * rnn.dic + k]
- += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
- };
-
- // @todo block k on simd-width
-#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
- /* icc 17.0 has a problem with simd collapse */ \
- && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
-#pragma omp parallel for simd collapse(2)
- for (int i = 0; i < rnn.n_gates; i++)
- for (int k = 0; k < rnn.dic; k++)
- body(i, k);
-#else
- parallel_nd(rnn.n_gates, rnn.dic, body);
-#endif
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
- assert(ldA * ldB * ldC != 0);
- extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
- &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
- assert(!"non packed gemm is disabled for int8");
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
- assert(transA == 'N');
- cblas_sgemm_compute(CblasColMajor, CblasPacked,
- (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
- ldB, beta, c_, ldC);
-#else
- UNUSED(transA);
- UNUSED(transB);
- UNUSED(m);
- UNUSED(n);
- UNUSED(k);
- UNUSED(alpha);
- UNUSED(ldA);
- UNUSED(b_);
- UNUSED(ldB);
- UNUSED(beta);
- UNUSED(c_);
- UNUSED(ldC);
- assert(!"packed gemm is disabled");
-#endif
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
- int8_t offseta = 0, offsetb = 0;
- int32_t offsetc = 0;
- cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
- CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
- ldB, offsetb, beta, c_, ldC, &offsetc);
-#else
- UNUSED(transA);
- UNUSED(transB);
- UNUSED(m);
- UNUSED(n);
- UNUSED(k);
- UNUSED(alpha);
- UNUSED(ldA);
- UNUSED(b_);
- UNUSED(ldB);
- UNUSED(beta);
- UNUSED(c_);
- UNUSED(ldC);
- assert(!"packed gemm is disabled");
-#endif
-}
-
-//*************** Grid computations strategy: linear ***************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_grid_execution_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
- AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
- AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
- AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- (rnn.n_states + 1), rnn.n_iter + 1,
- rnn.states_nld * rnn.states_ws_ld);
- AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
- rnn.gates_nld * rnn.gates_ws_ld);
- AOC<weights_data_t *, 3> weights_input(
- weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
- AOC<weights_data_t *, 3> weights_states(
- weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
- AOC<float*, 3> bias(
- bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
- AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
- rnn.n_dir,
- rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
- AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
- rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
- AOC<float, 3> diff_bias(
- diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
- AOC<float, 4> ws_grid(
- ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);
-
- // We run the grid of computation
- for (int dir = 0; dir < rnn.n_dir; dir++) {
- for (int j = 0; j < rnn.n_layer; j++) {
- int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;
-
- if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
- rnn.mb * rnn.n_iter, rnn.slc, 1.0,
- weights_input(lay, dir, 0), rnn.weights_iter_ld,
- &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
- &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
- }
-
- for (int i = 0; i < rnn.n_iter; i++) {
- int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;
- (this->*cell_func)(rnn,
- &(ws_states(lay + 1, dir, iter + 1, 0)),
- &(ws_c_states(lay + 1, dir, iter + 1, 0)),
- &(ws_diff_states(lay, dir, 0, iter, 0)),
- &(weights_input(lay, dir, 0)),
- &(weights_states(lay, dir, 0)),
- &(bias(lay, dir, 0)),
- &(ws_states(lay, dir, iter + 1, 0)),
- &(ws_states(lay + 1, dir, iter, 0)),
- &(ws_c_states(lay + 1, dir, iter, 0)),
- &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
- &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
- &(diff_weights_layer(lay, dir, 0)),
- &(diff_weights_iter(lay, dir, 0)),
- &(diff_bias(lay, dir, 0)),
- &(ws_gates(lay, dir, iter, 0)),
- &(ws_grid(lay, dir, iter, 0)),
- ws_cell_);
- }
-
- if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
- rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
- rnn.weights_layer_ld,
- (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld, 0.0,
- (acc_data_t *)(&(ws_diff_states(
- lay, dir, rnn.n_states, 0, 0))),
- rnn.states_ws_ld);
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
- rnn.mb * rnn.n_iter, 1.0,
- (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld,
- (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
- rnn.states_ws_ld, 1.0,
- (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
- rnn.diff_weights_layer_ld);
- }
- if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
- rnn.mb * rnn.n_iter, 1.0,
- (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld,
- (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
- rnn.states_ws_ld, 1.0,
- (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
- rnn.diff_weights_iter_ld);
- }
- }
- }
-}
-
-//********* GRID computations strategy: utility functions **********//
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
- const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
- float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
- const float *__restrict diff_dst_layer_) const {
-
- AOC<src_data_t, 4> ws_states(
- ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto xt_d = memory_desc_wrapper(pd()->src_md(0));
-
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto xxt = xt_ + xt_d.blk_off(it, b);
- src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
- src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
- if (rnn.exec_dir != r2l)
- for (int c = 0; c < rnn.slc; c++)
- ws_l2r_ptr[c] = xxt[c];
- if (rnn.exec_dir != l2r)
- for (int c = 0; c < rnn.slc; c++)
- ws_r2l_ptr[c] = xxt[c];
- });
-}
-
-template <>
-void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
- const float *diff_dst_layer_) const {
- AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));
-
- switch (rnn.exec_dir) {
- case bi_concat:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(
- rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
- = diff_dst_layer_x[rnn.dic + s];
- }
- });
- break;
- case bi_sum:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(
- rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- case l2r:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- case r2l:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x = diff_dst_layer_
- + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- default: assert(!"Unsupported direction"); break;
- }
-}
-
-/* For int8 configuration, input iteration states may be of types f32 or u8
- * Internally h_state is always stored in u8 and c_state is always stored in f32
- * If input states are of type u8 then h state is copied and c state is dequantized
- * If input states are of type f32 then h state is quantized and c_state is copied
- * */
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename input_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
- const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
- float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
- const input_data_t *__restrict firstit_states_,
- const float *__restrict diff_dst_iter_) const {
- AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
- const bool quantize = pd()->with_src_iter()
- && pd()->src_md(1)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_q = [&](input_data_t f) {
- if (quantize) {
- float qf = f * data_scale + data_shift;
- return qz_a1b0<float, src_data_t>()(qf);
- } else
- return (src_data_t)f;
- };
-
- const bool dequantize = pd()->with_src_iter()
- && pd()->src_md(1)->data_type == data_type::u8;
- auto maybe_deq = [&](input_data_t s) {
- if (dequantize)
- return (((float)s - data_shift) / data_scale);
- else
- return (float)s;
- };
- auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
- if (firstit_states_) {
- parallel_nd(
- rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
- for (int s = 0; s < rnn.sic; s++)
- ws_states(lay + 1, dir, 0, b, s) = maybe_q(
- firstit_states_[firstit_states_d.blk_off(
- lay, dir, 0, b, s)]);
- if (pd()->cell_kind() == alg_kind::vanilla_lstm)
- for (int s = 0; s < rnn.sic; s++)
- ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
- firstit_states_[firstit_states_d.blk_off(
- lay, dir, 1, b, s)]);
- });
- } else {
- parallel_nd(
- rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
- for (int j = 0; j < rnn.sic; j++) {
- ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
- ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
- }
- });
- }
-}
-
-template <>
-template <typename input_data_t>
-void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
- const input_data_t *firstit_states_,
- const float *diff_dst_iter_) const {
- AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
- if (diff_dst_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int b) {
- array_copy(&(ws_diff_states(
- lay, dir, state, rnn.n_iter, b, 0)),
- diff_dst_iter_
- + diff_dst_iter_d.blk_off(
- lay, dir, state, b),
- rnn.dic);
- });
- } else {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int i) {
- for (int j = 0; j < rnn.dic; j++)
- ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
- = 0.0f;
- });
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename dst_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
- const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
- const src_data_t *ws_states_, const float *ws_diff_states_) const {
-
- auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
- AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float shift = (pd()->attr()->rnn_data_qparams_.shift_);
- float scale = (pd()->attr()->rnn_data_qparams_.scale_);
-
- const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_deq = [&](src_data_t s) {
- if (dequantize)
- return (dst_data_t)(((float)s - shift) / scale);
- else
- return (dst_data_t)s;
- };
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- int dir = 0;
- if (rnn.exec_dir != r2l) {
- for (int s = 0; s < rnn.dic; s++) {
- dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
- = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
- }
- dir = 1;
- }
- if (rnn.exec_dir != l2r) {
- for (int s = 0; s < rnn.dic; s++)
- switch (rnn.exec_dir) {
- case bi_sum:
- dst_layer_[dst_layer_d.blk_off(it, b, s)]
- += maybe_deq(ws_states(
- rnn.n_layer, dir, rnn.n_iter - it, b, s));
- break;
- default:
- dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
- = maybe_deq(ws_states(
- rnn.n_layer, dir, rnn.n_iter - it, b, s));
- }
- }
- });
-}
-
-template <>
-template <typename dst_data_t>
-void ref_rnn_bwd_f32_t::copy_res_layer(
- const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
- const src_data_t *ws_states_, const float *ws_diff_states_) const {
- auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
- AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
- rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
- rnn.states_ws_ld);
-
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- int dir = 0;
- for (int s = 0; s < rnn.slc; s++) {
- float *dst_addr = diff_src_layer_
- + diff_src_layer_d.blk_off(
- (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
- b, dir * rnn.slc + s);
- float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
- if (rnn.n_dir - 1)
- res += ws_diff_states(
- 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
- dst_addr[0] = res;
- }
- });
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename output_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
- const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
- const src_data_t *ws_states_, float *ws_c_states_,
- const float *ws_diff_states_) const {
- auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
- AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
- const bool quantize = pd()->with_dst_iter()
- && pd()->dst_md(1)->data_type == data_type::u8
- && rnn.dt_conf != all_f32;
- auto maybe_q = [&](float f) {
- if (quantize) {
- float qf = f * data_scale + data_shift;
- return qz_a1b0<float, output_data_t>()(qf);
- } else
- return (output_data_t)f;
- };
-
- const bool dequantize = pd()->with_dst_iter()
- && pd()->dst_md(1)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_deq = [&](src_data_t s) {
- if (dequantize)
- return (output_data_t)(((float)s - data_shift) / data_scale);
- else
- return (output_data_t)s;
- };
- if (dst_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
- [&](int lay, int dir, int b) {
- for (int s = 0; s < rnn.dic; s++) {
- dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
- = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s));
- }
- if (pd()->cell_kind() == alg_kind::vanilla_lstm)
- for (int s = 0; s < rnn.dic; s++) {
- dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
- = maybe_q(ws_c_states(
- lay + 1, dir, rnn.n_iter, b, s));
- }
- });
- }
-}
-
-template <>
-template <typename output_data_t>
-void ref_rnn_bwd_f32_t::copy_res_iter(
- const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
- const src_data_t *ws_states_, float *ws_c_states_,
- const float *ws_diff_states_) const {
- auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
- AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
- rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
- rnn.states_ws_ld);
- if (diff_src_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int b) {
- for (int s = 0; s < rnn.sic; s++) {
- diff_src_iter_[diff_src_iter_d.blk_off(
- lay, dir, state, b, s)]
- = ws_diff_states(lay, dir, state, 0, b, s);
- }
- });
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
- /* Original set of bias provided by the user */
- AOC<const float, 5> b(
- b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
- /* Array of pointers initialized in packing */
- AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
- AOC<float, 3> scratch_bias(
- scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
-
- if (rnn.copy_bias) {
- parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
- [&](size_t i) { scratch_bias_[i] = b_[i]; });
- }
-
- for (int i = 0; i < rnn.n_layer; i++) {
- for (int d = 0; d < rnn.n_dir; d++) {
- int offset_bias = 0;
- for (int p = 0; p < rnn.n_parts_bias; p++) {
- bias(i, d, p) = rnn.copy_bias
- ? (float *) &scratch_bias(i, d, offset_bias)
- : (float *) &b(i, d, offset_bias);
- offset_bias += rnn.parts_bias[p] * rnn.dic;
- }
- }
- }
-
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_finalize_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
- if (rnn.dt_conf != all_f32) {
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
- float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
- bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
- for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
- for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
- size_t off = i * rnn.n_bias * rnn.dic + j;
- float weights_scale
- = scale_per_oc ? weights_scales[j] : weights_scales[0];
- scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
- * data_shift / (weights_scale * data_scale);
- }
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
- weights_type>::assign_packed_weights)) {
- assert(md->format_kind == format_kind::rnn_packed);
- const auto packed_desc = md->format_desc.rnn_packed_desc;
- AOC<weights_data_t *, 3> weights(weights_,
- rnn.n_layer, rnn.n_dir, packed_desc.n_parts);
-
- size_t offset_packed = 0;
- for (int l = 0; l < rnn.n_layer; l++)
- for (int d = 0; d < rnn.n_dir; d++) {
- for (int p = 0; p < packed_desc.n_parts; p++) {
- weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
- offset_packed
- += packed_desc.part_pack_size[p] / sizeof(weights_data_t);
- }
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
- assert(md->format_kind == format_kind::blocked);
- const auto &blk = md->format_desc.blocking;
- /* Original set of weights provided by the user */
- AOC<const weights_data_t, 3> w(w_,
- rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
- /* Array of pointers for each part of weights */
- AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);
-
- for (int i = 0; i < rnn.n_layer; i++)
- for (int d = 0; d < rnn.n_dir; d++) {
- size_t offset_weights = 0;
- for (int p = 0; p < n_parts; p++) {
- weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
- offset_weights += gates_per_part[p] * blk.strides[3];
- }
- }
-}
-
-//********************* Execution function *********************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
- const exec_ctx_t &ctx) const {
- const rnn_conf_t &rnn = this->pd()->rnn_;
- auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
- auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
- auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
- auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
- auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
-
- auto dst_last_layer = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
- auto dst_last_iter = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));
-
- auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
- auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);
-
- auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
- auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
- auto w_iter_comp = reinterpret_cast<const float *>(
- iter_weights_n_comp + rnn.weights_iter_comp_offset);
- auto w_layer_comp = reinterpret_cast<const float *>(
- layer_weights_n_comp + rnn.weights_layer_comp_offset);
-
- auto scratchpad = this->scratchpad(ctx);
-
- auto ptr_wei_layer
- = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
- auto ptr_wei_iter
- = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
- auto ptr_bias =
- scratchpad.template get<float *>(key_rnn_ptrs_bia);
-
-    // fetching buffers from the workspace
-    // if no workspace was provided, we use the scratchpad
- char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
- char *ws_ptr = nullptr;
- if (rnn.use_workspace)
- ws_ptr = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
-
- char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
- acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
- src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
- float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
- float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
- float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
- float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);
-
- auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
- auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);
-
- auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
- auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
- auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
-
- // Fetching extra buffers from scratchpad
- float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);
-
- // initialize diff_states to 0
- if (aprop == prop_kind::backward)
- array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));
-
-    /* Pack (if using the packed gemm API) or copy (if input arrays have a bad
-     * leading dimension) */
- (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
-
- (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
- rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
- rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
- rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
- ptr_bias, bias, ws_bias);
- (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
- rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
- rnn.n_parts_weights_layer, rnn.parts_weights_layer,
- rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
- bias, ws_bias);
-
- (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
-
- // we first need to copy the initial states and input into ws
- copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
- if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
- (const float *)states, diff_dst_iter);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
- copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
- (const uint8_t *)states, diff_dst_iter);
- else
- assert(!"unimplemented");
-
- // run the execution on the grid
- (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
- ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
- diff_weights_layer, diff_weights_iter, diff_bias);
-
- // Finally we copy the results to the result buffers
- if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
- ws_diff_states);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
- copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
- ws_states, ws_diff_states);
- else
- assert(!"unimplemented");
-
- if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
- ws_c_states, ws_diff_states);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
- copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
- ws_c_states, ws_diff_states);
- else
- assert(!"unimplemented");
-};
-
-/* Fix for MSVS warning C4661 */
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
-
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
-template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
-
-#undef AOC
-}
-}
-}
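The maybe_q / maybe_deq lambdas in copy_res_iter and copy_res_layer above boil down to an affine round trip between f32 and u8 driven by the rnn_data_qparams scale and shift. The following standalone sketch illustrates only that round trip; the saturating round here is a simplified stand-in for mkl-dnn's qz_a1b0 functor, and the example values are illustrative, not taken from the library.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Quantize an f32 state to u8: q = round(f * scale + shift), clamped to [0, 255].
    // (Simplified stand-in for qz_a1b0<float, uint8_t>.)
    static uint8_t quantize(float f, float scale, float shift) {
        float qf = f * scale + shift;
        return (uint8_t)std::min(255.0f, std::max(0.0f, std::nearbyint(qf)));
    }

    // Dequantize a u8 state back to f32, inverting the affine transform above.
    static float dequantize(uint8_t q, float scale, float shift) {
        return ((float)q - shift) / scale;
    }

    int main() {
        const float scale = 64.f, shift = 128.f;  // example qparams
        uint8_t q = quantize(0.5f, scale, shift); // 0.5 * 64 + 128 = 160
        std::printf("%d -> %.3f\n", (int)q, dequantize(q, scale, shift)); // 160 -> 0.500
    }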
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
deleted file mode 100644
index 6f449a9016..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
+++ /dev/null
@@ -1,328 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_RNN_HPP
-#define CPU_REF_RNN_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "memory_tracking.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-#include "../cpu_isa_traits.hpp"
-#include "../gemm/os_blas.hpp"
-
-#include "cpu_rnn_pd.hpp"
-#include "../cpu_primitive.hpp"
-#include "rnn_utils.hpp"
-#include "jit_uni_rnn_postgemm.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <alg_kind_t alg_kind, prop_kind_t prop_kind>
-float activation(float s, float alpha, float cliping, float dd);
-
-template <prop_kind_t aprop, impl::data_type_t src_type,
- impl::data_type_t weights_type>
-struct _ref_rnn_common_t : public cpu_primitive_t {
- typedef typename prec_traits<src_type>::type src_data_t;
- typedef typename prec_traits<weights_type>::type weights_data_t;
- typedef typename utils::conditional<src_type == data_type::u8, int32_t,
- float>::type acc_data_t;
-
- using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>;
-
- typedef rnn_elemwise_sig((class_name::*elemwise_f));
- typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
- typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
-
- typedef rnn_gemm_sig((class_name::*gemm_t));
- typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
- typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
- typedef rnn_weights_assign_sig((class_name::*weights_assign_t));
-
- using base_pd_t =
- typename utils::conditional<false || aprop == prop_kind::forward,
- cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
-
- struct pd_t : public base_pd_t {
- using base_pd_t::base_pd_t;
-
- DECLARE_COMMON_PD_T("ref:any", class_name);
-
- status_t init() {
- using namespace prop_kind;
- using namespace utils;
- using namespace format_tag;
- using namespace rnn_utils;
- const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
-
- data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type;
- data_type_t weights_iter_dt
- = this->desc()->weights_iter_desc.data_type;
- data_type_t weights_layer_dt
- = this->desc()->weights_layer_desc.data_type;
-
- bool ok = true
- && one_of(cell_kind, alg_kind::vanilla_rnn,
- alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
- alg_kind::gru_linear_before_reset)
- && IMPLICATION(aprop == prop_kind::forward,
- one_of(this->desc()->prop_kind, forward_training,
- forward_inference))
- && IMPLICATION(aprop == backward,
- one_of(this->desc()->prop_kind, backward))
- && src_layer_dt == src_type
- && everyone_is(
- weights_type, weights_iter_dt, weights_layer_dt)
- && this->set_default_params() == status::success
- && this->with_bias();
- if (!ok)
- return status::unimplemented;
-
- init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1),
- this->weights_md(0), this->weights_md(1), this->dst_md(0));
-
- if (rnn_.dt_conf == all_f32)
- ok = ok && this->attr()->has_default_values();
-
- // Set weights descriptors to desired format
- memory_desc_t new_weights_layer_md = *this->weights_md(0);
- CHECK(set_expected_desc(rnn_, new_weights_layer_md, false));
- if (this->weights_layer_md_.format_kind == format_kind::any) {
- this->weights_layer_md_ = new_weights_layer_md;
- } else if (this->weights_layer_md_.format_kind
- == format_kind::rnn_packed) {
- if (this->weights_layer_md_ != new_weights_layer_md)
- return status::unimplemented;
- }
-
- memory_desc_t new_weights_iter_md = *this->weights_md(1);
- CHECK(set_expected_desc(rnn_, new_weights_iter_md, true));
- if (this->weights_iter_md_.format_kind == format_kind::any) {
- this->weights_iter_md_ = new_weights_iter_md;
- } else if (this->weights_iter_md_.format_kind
- == format_kind::rnn_packed) {
- if (this->weights_iter_md_ != new_weights_iter_md)
- return status::unimplemented;
- }
-
- CHECK(this->check_layout_consistency());
-
- set_conf(rnn_, *this->desc(), this->weights_md(0),
- this->weights_md(1), this->diff_weights_md(0),
- this->diff_weights_md(1));
-
- size_t scratchpad_sz{0}, ws_sz{0};
- get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);
-
- // initialize the workspace if needed
- if (rnn_.is_training) {
- dims_t ws_dims = { (int)ws_sz };
- mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims,
- data_type::u8, format_tag::x);
- }
-
- init_scratchpad(scratchpad_sz);
-
- return status::success;
- }
-
- rnn_utils::rnn_conf_t rnn_;
-
- private:
- void init_scratchpad(size_t scratchpad_sz) {
- using namespace memory_tracking::names;
- auto scratchpad = this->scratchpad_registry().registrar();
- scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096);
-
- int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1;
- int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
- scratchpad.book(key_rnn_ptrs_wei_layer,
- sizeof(float *) * ptr_wei_sz);
- scratchpad.book(key_rnn_ptrs_wei_iter,
- sizeof(float *) * ptr_wei_sz);
- scratchpad.book(key_rnn_ptrs_bia,
- sizeof(float *) * ptr_wei_sz);
- }
- };
-
- _ref_rnn_common_t(const pd_t *apd)
- : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) {
- /// @todo set max_feature_size assuming that we limit the number of
- /// iterations and layer to one if slc != dic and sic != dic
- /// respectively
-
- bias_preparation_func = &class_name::bias_prepare;
- bias_finalization_func = &class_name::bias_finalize;
-
- auto set_gemm_funcs
- = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) {
- if (packed_gemm) {
- g = &class_name::packed_gemm;
- a = &class_name::assign_packed_weights;
- } else {
- g = &class_name::gemm;
- a = &class_name::assign_weights;
- }
- };
- set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
- weights_iter_assign_func);
-
- set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
- weights_layer_assign_func);
-
- switch (pd()->cell_kind()) {
- case alg_kind::vanilla_lstm:
- cell_func = &class_name::cell_execution;
- if (aprop == prop_kind::forward) {
- if (mayiuse(avx512_core))
- rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>(
- pd()->rnn_, pd()->attr());
- else if (mayiuse(avx2))
- rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>(
- pd()->rnn_, pd()->attr());
- else if (mayiuse(sse42))
- rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>(
- pd()->rnn_, pd()->attr());
- assert(rnn_postgemm_ != nullptr);
- rnn_postgemm_->init();
- }
- elemwise_func = &class_name::lstm_elemwise;
- break;
- case alg_kind::vanilla_rnn: // @todo switch on cell kind
- cell_func = &class_name::cell_execution;
- elemwise_func = &class_name::rnn_elemwise;
- switch (pd()->activation_kind()) {
- case alg_kind::eltwise_relu:
- activation_func = &activation<alg_kind::eltwise_relu, aprop>;
- break;
- case alg_kind::eltwise_tanh:
- activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
- break;
- case alg_kind::eltwise_logistic:
- activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
- break;
- default: break;
- }
- break;
- case alg_kind::vanilla_gru:
- cell_func = &class_name::cell_execution_gru;
- break;
- case alg_kind::gru_linear_before_reset:
- cell_func = &class_name::cell_execution_gru_lbr;
- elemwise_func = &class_name::gru_lbr_elemwise;
- break;
- default: break;
- }
-
- grid_computation = &class_name::linear_execution;
-
- size_t scratchpad_size, workspace_size;
- rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_,
- ws_c_states_offset_, ws_diff_states_offset_,
- ws_grid_comp_offset_, ws_cell_comp_offset_,
- ws_bias_offset_, scratchpad_size, workspace_size);
- }
-
- ~_ref_rnn_common_t() {}
-
- // typedef typename prec_traits::type data_t;
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- execute_(ctx);
- return status::success;
- }
-
-private:
- void execute_(const exec_ctx_t &ctx) const;
- rnn_grid_execution_sig(linear_execution);
- rnn_cell_execution_sig(cell_execution);
- rnn_cell_execution_sig(cell_execution_gru);
- rnn_cell_execution_sig(cell_execution_gru_lbr);
- rnn_elemwise_sig(rnn_elemwise);
- rnn_elemwise_sig(lstm_elemwise);
- rnn_elemwise_sig(gru_lbr_elemwise);
- rnn_gemm_sig(gemm);
- rnn_gemm_sig(packed_gemm);
- rnn_bias_prepare_sig(bias_prepare);
- rnn_bias_finalize_sig(bias_finalize);
- rnn_weights_assign_sig(assign_weights);
- rnn_weights_assign_sig(assign_packed_weights);
-
- float (*activation_func)(float dd, float s, float alpha, float cliping);
-
- void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_diff_states_,
- const src_data_t *xt_, const float *diff_dst_layer) const;
-
- template <typename input_data_t>
- void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_,
- const input_data_t *firstit_states_,
- const float *diff_dst_iter) const;
-
- template <typename dst_data_t>
- void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
- dst_data_t *dst_layer_, float *diff_src_layer,
- const src_data_t *ws_states_, const float *ws_diff_states_) const;
-
- template <typename output_data_t>
- void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
- output_data_t *dst_iter_, float *diff_src_iter,
- const src_data_t *ws_states_, float *ws_c_states,
- const float *ws_diff_states_) const;
-
- void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
- const acc_data_t *ws_gates_, float *diff_bias_) const;
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-
- size_t ws_gates_offset_;
- size_t ws_states_offset_;
- size_t ws_c_states_offset_;
- size_t ws_bias_offset_;
- size_t ws_diff_states_offset_;
- size_t ws_grid_comp_offset_;
- size_t ws_cell_comp_offset_;
- jit_uni_rnn_postgemm_kernel *rnn_postgemm_;
-
- grid_execution_f grid_computation;
- cell_execution_f cell_func;
-
- bias_prepare_t bias_preparation_func;
- bias_finalize_t bias_finalization_func;
- weights_assign_t weights_layer_assign_func;
- weights_assign_t weights_iter_assign_func;
-
- gemm_t gemm_layer_func;
- gemm_t gemm_iter_func;
- elemwise_f elemwise_func;
-};
-
-using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
-using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
-using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
-}
-}
-}
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
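The header above selects its cell, element-wise, and gemm kernels once in the constructor and stores them as pointer-to-member fields (cell_func, elemwise_func, gemm_layer_func, ...), which execute_() later invokes via (this->*...)(). The snippet below is a stripped-down sketch of that dispatch pattern with placeholder names and only two kinds; it is not the mkl-dnn types themselves.

    #include <cstdio>

    // Minimal illustration of the pointer-to-member dispatch used by _ref_rnn_common_t.
    struct rnn_ref {
        enum cell_kind_t { vanilla_rnn, vanilla_lstm };

        explicit rnn_ref(cell_kind_t kind) {
            // Pick the element-wise kernel once, at construction time.
            elemwise_func = (kind == vanilla_lstm) ? &rnn_ref::lstm_elemwise
                                                   : &rnn_ref::rnn_elemwise;
        }

        void execute() const {
            // Invoke through the member-function pointer, as execute_() does above.
            (this->*elemwise_func)();
        }

    private:
        void rnn_elemwise() const { std::printf("plain rnn elementwise\n"); }
        void lstm_elemwise() const { std::printf("lstm elementwise\n"); }

        void (rnn_ref::*elemwise_func)() const;
    };

    int main() {
        rnn_ref{rnn_ref::vanilla_lstm}.execute(); // prints "lstm elementwise"
    }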
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
deleted file mode 100644
index 78cdedbae4..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
+++ /dev/null
@@ -1,380 +0,0 @@
-/*******************************************************************************
- * Copyright 2018 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#ifndef CPU_RNN_REORDERS_HPP
-#define CPU_RNN_REORDERS_HPP
-
-#include <assert.h>
-
-#include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-#include "simple_q10n.hpp"
-#include "cpu_reorder_pd.hpp"
-#include "../gemm/os_blas.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t type_i, data_type_t type_o>
-struct rnn_data_reorder_t : public cpu_primitive_t {
- struct pd_t : public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
- const memory_desc_wrapper id(src_md), od(dst_md);
- bool args_ok = true
- && id.data_type() == type_i
- && od.data_type() == type_o
- && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc)
- && od == id;
- if (!args_ok) return status::invalid_arguments;
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return out_of_memory;
- if (_pd->init() != success) { delete _pd; return unimplemented; }
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
- };
-
-private:
- typedef typename prec_traits<type_i>::type in_data_t;
- typedef typename prec_traits<type_o>::type out_data_t;
-
- rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
- auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
- const memory_desc_wrapper &input_d = pd()->src_md();
- const memory_desc_wrapper &output_d = pd()->dst_md();
- const size_t nelems = input_d.nelems();
- const float scale = pd()->attr()->rnn_data_qparams_.scale_;
- const float shift = pd()->attr()->rnn_data_qparams_.shift_;
-
- parallel_nd(nelems, [&](size_t i) {
- float in = (float)input[input_d.off_l(i)] * scale + shift;
- output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
- });
-
- return status::success;
- }
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <data_type_t type_i, data_type_t type_o>
-struct rnn_weights_reorder_t : public cpu_primitive_t {
- struct pd_t : public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
-#if !USE_MKL_PACKED_GEMM
- return status::unimplemented;
-#endif
- const memory_desc_wrapper id(src_md), od(dst_md);
- bool args_ok = true
- && id.data_type() == type_i
- && od.data_type() == type_o
- && od.format_kind() == format_kind::rnn_packed
- && od.rnn_packed_desc().format == mkldnn_ldigo_p
- && od.rnn_packed_desc().n_parts == 1
- && attr != nullptr;
- if (!args_ok) return status::invalid_arguments;
-
- format_tag_t itag = id.matches_one_of_tag(
- format_tag::ldigo, format_tag::ldgoi);
- if (itag == format_tag::undef) return status::invalid_arguments;
-
- const int mask = attr->rnn_weights_qparams_.mask_;
- if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return out_of_memory;
- _pd->itag_ = itag;
- if (_pd->init() != success) { delete _pd; return unimplemented; }
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
-
- status_t init() {
- status_t status = cpu_reorder_pd_t::init();
- if (status != status::success) return status;
-
- init_scratchpad();
-
- return status::success;
- }
-
- format_tag_t itag_ = mkldnn_format_tag_undef;
-
- private:
- void init_scratchpad() {
- const memory_desc_wrapper id(src_md());
- const size_t nelems = id.nelems();
- const auto &dims = id.dims();
-
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- size_t quantization_size = sizeof(int8_t) * nelems;
- size_t reduction_size = itag_ == ldigo
- ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
- * dims[1] * dims[3] * dims[4]
- : 0;
- scratchpad.book(
- key_reorder_rnn_weights_quantization, quantization_size);
- scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
- }
- };
-
-private:
- typedef typename prec_traits<type_i>::type in_data_t;
- typedef typename prec_traits<type_o>::type out_data_t;
-
- rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
-#if USE_MKL_PACKED_GEMM
- auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
- auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
- const memory_desc_wrapper &input_d = pd()->src_md();
- const memory_desc_wrapper &output_d = pd()->dst_md();
- const auto &dims = input_d.dims();
-
- const int L = dims[0];
- const int D = dims[1];
- const int I = dims[2];
- const int G = dims[3];
- const int O = dims[4];
-
- const bool is_igo = pd()->itag_ == format_tag::ldigo;
-
- /* Quantize input & compute compensation */
- auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>(
- memory_tracking::names::key_reorder_rnn_weights_quantization);
- auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>(
- memory_tracking::names::key_reorder_rnn_weights_reduction);
- float *comp = reinterpret_cast<float *>(
- output + output_d.rnn_packed_desc().offset_compensation);
- const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
- const int mask = pd()->attr()->rnn_weights_qparams_.mask_;
-
- if (is_igo) {
- int nthr = mkldnn_get_max_threads();
- int LD_nthr = nstl::min(L * D, nthr);
- int I_nthr = nstl::min(I, nthr / LD_nthr);
- parallel(nthr, [&](const int ithr, const int nthr) {
- int LD_ithr = -1, LD_s = -1, LD_e = -1;
- int I_ithr = -1, I_s = -1, I_e = -1;
- if (ithr < LD_nthr * I_nthr) {
- LD_ithr = ithr % LD_nthr;
- I_ithr = ithr / LD_nthr;
- balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
- balance211(I, I_nthr, I_ithr, I_s, I_e);
- }
- int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
- for (int ld = LD_s; ld < LD_e; ld++) {
- for (int go = 0; go < G * O; go++)
- comp_ithr[ld * G * O + go] = 0;
- for (int i = I_s; i < I_e; i++) {
- PRAGMA_OMP_SIMD()
- for (int go = 0; go < G * O; go++) {
- const float s = scales[(mask == 0) ? 0 : go];
- int8_t q = qz_b0<in_data_t, out_data_t>()(
- input[ld * I * G * O + i * G * O + go], s);
- quantized[ld * I * G * O + i * G * O + go]
- = (int32_t)q;
- comp_ithr[ld * G * O + go] += (int32_t)q;
- }
- }
- }
- });
- parallel_nd(L * D * G * O,
- [&](int s) { comp[s] = saturate<float>(reduction[s]); });
- for (int i = 1; i < I_nthr; i++) {
- parallel_nd(L * D * G * O, [&](int s) {
- comp[s] += saturate<float>(
- reduction[i * L * D * G * O + s]);
- });
- }
- } else {
- parallel_nd(L * D, G * O, [&](int ld, int go) {
- int32_t compensation = 0;
- const float s = scales[(mask == 0) ? 0 : go];
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < I; i++) {
- int8_t q = qz_b0<in_data_t, out_data_t>()(
- input[ld * G * O * I + go * I + i], s);
- compensation += (int32_t)q;
- quantized[ld * G * O * I + go * I + i] = q;
- }
- comp[ld * G * O + go] = saturate<float>(compensation);
- });
- }
-
- /* Pack */
- auto off_igo = [&](int l, int d, int i, int g, int o) {
- return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
- };
- auto off_goi = [&](int l, int d, int i, int g, int o) {
- return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
- };
- int n_parts = output_d.rnn_packed_desc().n_parts;
- const size_t *size_packed_cell
- = output_d.rnn_packed_desc().part_pack_size;
- const int *parts = output_d.rnn_packed_desc().parts;
- const int n = output_d.rnn_packed_desc().n;
- char *to_pack = output;
- for (int l = 0; l < L; l++) {
- for (int d = 0; d < D; d++) {
- for (int p = 0; p < n_parts; p++) {
- int g = (p > 0) ? parts[p - 1] : 0;
- int m_p = parts[p] * O;
- int k_p = I;
- cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
- is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
- &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
- off_goi(l, d, g, 0, 0)],
- is_igo ? G * O : I, to_pack);
- to_pack += size_packed_cell[p];
- }
- }
- }
-#endif
- return status::success;
- }
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-template <>
-struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
- : public cpu_primitive_t {
- struct pd_t : public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
-#if !USE_MKL_PACKED_GEMM
- return status::unimplemented;
-#endif
- const memory_desc_wrapper id(src_md), od(dst_md);
- bool args_ok = true
- && id.data_type() == data_type::f32
- && od.data_type() == data_type::f32
- && od.format_kind() == format_kind::rnn_packed
- && utils::one_of(od.rnn_packed_desc().format,
- mkldnn_ldigo_p, mkldnn_ldgoi_p)
- && attr->has_default_values();
- if (!args_ok) return status::invalid_arguments;
-
- format_tag_t itag = id.matches_one_of_tag(
- format_tag::ldigo, format_tag::ldgoi);
- if (itag == format_tag::undef) return status::invalid_arguments;
-
- const int mask = attr->rnn_weights_qparams_.mask_;
- if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return out_of_memory;
- if (_pd->init() != success) { delete _pd; return unimplemented; }
- _pd->itag_ = itag;
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
-
- format_tag_t itag_;
- };
-
-private:
- rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
-#if USE_MKL_PACKED_GEMM
- auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM);
- auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO);
- const memory_desc_wrapper &input_d = pd()->src_md();
- const memory_desc_wrapper &output_d = pd()->dst_md();
- const auto &dims = input_d.dims();
- const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
- const int L = dims[0];
- const int D = dims[1];
- const int I = dims[2];
- const int G = dims[3];
- const int O = dims[4];
-
- /* Pack */
- bool cross_case = false
- || (pd()->itag_ == format_tag::ldigo
- && rnn_pdata.format == mkldnn_ldgoi_p)
- || (pd()->itag_ == format_tag::ldgoi
- && rnn_pdata.format == mkldnn_ldigo_p);
- auto trans = cross_case ? CblasTrans : CblasNoTrans;
- int n_parts = rnn_pdata.n_parts;
- const size_t *size_packed_cell = rnn_pdata.part_pack_size;
- const int *parts = rnn_pdata.parts;
- const int n = rnn_pdata.n;
-
- const bool is_igo = pd()->itag_ == format_tag::ldigo;
- auto off_igo = [&](int l, int d, int i, int g, int o) {
- return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
- };
- auto off_goi = [&](int l, int d, int i, int g, int o) {
- return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
- };
- for (int l = 0; l < L; l++) {
- for (int d = 0; d < D; d++) {
- for (int p = 0; p < n_parts; p++) {
- int g = (p > 0) ? parts[p - 1] : 0;
- int m_p = is_igo ? parts[p] * O : I;
- int k_p = is_igo ? I : parts[p] * O;
- int ld = is_igo ? G * O : I;
- cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
- k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
- off_goi(l, d, 0, g, 0)],
- ld, output);
- output += size_packed_cell[p] / sizeof(float);
- }
- }
- }
-#endif
- return status::success;
- }
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-} // namespace cpu
-} // namespace impl
-} // namespace mkldnn
-
-#endif
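The int8 weights reorder above also accumulates a per-output-channel compensation term, the sum of the quantized weights along the input dimension, and stores it after the packed data; bias_finalize in ref_rnn.cpp then subtracts comp * data_shift / (weights_scale * data_scale) from the bias so that the constant data shift introduced by activation quantization cancels out of the u8 GEMM. Below is a minimal sketch of the compensation accumulation only, with hypothetical names and no packing.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Quantize the weights of one output channel and accumulate the compensation
    // (sum of the quantized values), a simplified core of the reorder above.
    static float quantize_row_with_comp(const std::vector<float> &w, float weights_scale,
                                        std::vector<int8_t> &q) {
        q.resize(w.size());
        int32_t comp = 0;
        for (int i = 0; i < (int)w.size(); ++i) {
            float scaled = std::nearbyint(w[i] * weights_scale);
            int8_t qi = (int8_t)std::min(127.0f, std::max(-128.0f, scaled)); // saturate to s8
            q[i] = qi;
            comp += qi;         // compensation: sum over the input dimension
        }
        return (float)comp;     // stored as float next to the packed weights
    }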
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
deleted file mode 100644
index 1d60415cbc..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
+++ /dev/null
@@ -1,426 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "c_types_map.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-#include "rnn_utils.hpp"
-#include "type_helpers.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace rnn_utils;
-using namespace format_tag;
-using namespace rnn_packed_format;
-using namespace data_type;
-
-bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) {
- if (md.format_kind() != format_kind::blocked)
- return false;
-
- auto blk = md.blocking_desc();
- auto str = blk.strides;
- auto dims = md.dims();
- return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1
- && str[3] == dims[4] && str[1] == str[2] * dims[2]
- && str[0] == str[1] * dims[1];
-};
-
-bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) {
- if (md.format_kind() != format_kind::blocked)
- return false;
-
- auto blk = md.blocking_desc();
- auto str = blk.strides;
- auto dims = md.dims();
- return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1
- && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3]
- && str[0] == str[1] * dims[1];
-};
-
-void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
- const memory_desc_wrapper &src_layer_d,
- const memory_desc_wrapper &src_iter_d,
- const memory_desc_wrapper &weights_layer_d,
- const memory_desc_wrapper &weights_iter_d,
- const memory_desc_wrapper &dst_layer_d) {
- rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
- prop_kind::forward_inference);
- rnn.is_training = utils::one_of(
- rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
- rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset;
-
- switch (rd.direction) {
- case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break;
- case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break;
- case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break;
- case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break;
- default: break;
- }
-
- if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
- weights_layer_d.data_type()))
- rnn.dt_conf = all_f32;
- else if (dst_layer_d.data_type() == u8) {
- if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
- rnn.dt_conf = u8u8u8u8;
- else
- rnn.dt_conf = f32u8f32u8;
- } else {
- if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
- rnn.dt_conf = u8u8u8f32;
- else
- rnn.dt_conf = f32u8f32f32;
- }
-
- rnn.n_layer = weights_layer_d.dims()[0];
- rnn.n_iter = src_layer_d.dims()[0];
- rnn.n_dir = weights_layer_d.dims()[1];
- rnn.n_gates = weights_layer_d.dims()[3];
- rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc);
- rnn.n_bias = rnn.n_gates + rnn.is_lbr;
- rnn.mb = src_layer_d.dims()[1];
- rnn.sic = weights_iter_d.dims()[2];
- rnn.slc = weights_layer_d.dims()[2];
- rnn.dic = weights_layer_d.dims()[4];
- rnn.dlc = dst_layer_d.dims()[2];
-
- rnn.gates_ld = rnn.dic * rnn.n_gates;
- rnn.gates_nld = rnn.mb;
- rnn.states_nld = rnn.mb;
-
- /* Set the correct number of weights parts */
- bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru;
- rnn.n_parts_weights_layer = 1;
- rnn.parts_weights_layer[0] = rnn.n_gates;
- rnn.parts_weights_layer[1] = 0;
-
- rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
- rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
- rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;
-
- rnn.n_parts_bias = 1;
- rnn.parts_bias[0] = rnn.n_bias;
- rnn.parts_bias[1] = 0;
-
-    /* Decide which gemm implementation to use: packed/nonpacked jit/cblas,
-     * and whether to merge gemm across iterations */
- bool is_int8 = rnn.dt_conf != all_f32;
- rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
- || is_int8;
- bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru,
- alg_kind::gru_linear_before_reset);
- rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8;
- bool is_inference = !rnn.is_training;
-
- rnn.use_jit_gemm = !mayiuse(avx512_mic)
- && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
- || (rnn.is_training && rnn.dic < 500));
-
-    /* Decide whether to copy bias */
- rnn.copy_bias = rnn.dt_conf != all_f32;
-
-#if USE_MKL_PACKED_GEMM
- rnn.use_layer_packed_gemm
- = (weights_layer_d.format_kind() == format_kind::any
- && rnn.slc > 760 && rnn.dic > 760 && is_inference)
- || is_int8; // packed gemm is the only supported option for int8
- rnn.use_iter_packed_gemm
- = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760
- && rnn.dic > 760 && is_inference)
- || is_int8;
-#else
- rnn.use_layer_packed_gemm = false;
- rnn.use_iter_packed_gemm = false;
-#endif
-
- /* Set packed gemm sizes */
- if (rnn.use_layer_packed_gemm) {
- rnn.weights_layer_pack_size = 0;
- for (int p = 0; p < rnn.n_parts_weights_layer; p++) {
- int m_p = rnn.is_fwd
- ? (rnn.parts_weights_layer[p] * rnn.dic)
- : rnn.slc;
- int k_p = rnn.is_fwd
- ? rnn.slc
- : (rnn.parts_weights_layer[p] * rnn.dic);
- int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb;
-
-#if USE_MKL_PACKED_GEMM
- if (rnn.dt_conf == all_f32)
- rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size(
- CblasAMatrix, m_p, n_p, k_p);
- else
- rnn.part_weights_layer_pack_size[p]
- = cblas_gemm_s8u8s32_pack_get_size(
- CblasAMatrix, m_p, n_p, k_p);
-#else
- UNUSED(m_p);
- UNUSED(k_p);
- UNUSED(n_p);
- rnn.part_weights_layer_pack_size[p] = 0;
-#endif
- rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir
- * rnn.part_weights_layer_pack_size[p];
- }
- rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size;
- rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
- * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float);
- }
-
- if (rnn.use_iter_packed_gemm) {
- rnn.weights_iter_pack_size = 0;
- for (int p = 0; p < rnn.n_parts_weights_iter; p++) {
- int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) :
- rnn.sic;
- int k_p = rnn.is_fwd ? rnn.sic :
- (rnn.parts_weights_iter[p] * rnn.dic);
- int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb;
-
-#if USE_MKL_PACKED_GEMM
- if (rnn.dt_conf == all_f32)
- rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size(
- CblasAMatrix, m_p, n_p, k_p);
- else
- rnn.part_weights_iter_pack_size[p]
- = cblas_gemm_s8u8s32_pack_get_size(
- CblasAMatrix, m_p, n_p, k_p);
-#else
- UNUSED(m_p);
- UNUSED(k_p);
- UNUSED(n_p);
- rnn.part_weights_iter_pack_size[p] = 0;
-#endif
- rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir
- * rnn.part_weights_iter_pack_size[p];
- }
- rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size;
- rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
- * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float);
- }
-
-}
-
-void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
- const memory_desc_wrapper &weights_layer_d,
- const memory_desc_wrapper &weights_iter_d,
- const memory_desc_wrapper &diff_weights_layer_d,
- const memory_desc_wrapper &diff_weights_iter_d) {
-
- /* Set leading dimensions for input weights arrays depending on input format
- */
- rnn.weights_layer_is_packed
- = weights_layer_d.format_kind() == format_kind::rnn_packed;
- rnn.weights_iter_is_packed
- = weights_iter_d.format_kind() == format_kind::rnn_packed;
-
- auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
- ld = 0; nld = 0;
- if (md.is_blocking_desc()) {
- if (is_ldigo(md)) {
- ld = (int)md.blocking_desc().strides[2];
- nld = md.dims()[2];
- } else if (is_ldgoi(md)) {
- ld = (int)md.blocking_desc().strides[4];
- nld = md.dims()[3] * md.dims()[4];
- } else
- assert(!"unsupported weights format");
- }
- };
- set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
- set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
- if (!rnn.is_fwd) {
- set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
- rnn.diff_weights_layer_nld);
- set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
- rnn.diff_weights_iter_nld);
- }
-
- int sizeof_states_dt
- = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t);
- rnn.states_ws_ld
- = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)),
- sizeof_states_dt);
- rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float));
-
- /* Set workspace sizes to store:
-     * states to compute a pass
-     * diff states to compute bwd pass (training only)
- * intermediate results from the gates
- */
- rnn.use_workspace = rnn.is_training;
- rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
- * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt;
- bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm;
- rnn.ws_c_states_size = is_lstm
- ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
- * rnn.states_ws_ld * sizeof(float)
- : 0;
- rnn.ws_diff_states_size = rnn.is_training
- ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1)
- * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld
- * sizeof(float)
- : (size_t)0;
- rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
- * rnn.gates_ws_ld * sizeof(float);
-
- /* set other sizes */
- rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float);
- rnn.ws_cell_comp_size
- = rnn.is_lbr || rnn.dt_conf != all_f32
- ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float)
- : 0;
- rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
- * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
- rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic
- * sizeof(float);
-}
-
-int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
-    // we want matrix leading dimensions to be 64-byte aligned,
-    // and not divisible by 256 to avoid 4K aliasing effects
- int ld = rnd_up(dim, 64 / sizeof_dt);
- return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
-}
-
-void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
- size_t &ws_states_offset, size_t &ws_c_states_offset,
- size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
- size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
- size_t &scratchpad_size, size_t &workspace_size) {
-
- const size_t page_size = 4096; // 2097152;
- size_t current_offset;
- /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
- * otherwise */
- current_offset = 0; // assumes the workspace base pointer is page aligned
- ws_gates_offset = current_offset;
- current_offset += rnn.ws_gates_size;
-
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_states_offset = current_offset;
- current_offset += rnn.ws_states_size;
-
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_c_states_offset = current_offset;
- current_offset += rnn.ws_c_states_size;
-
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_diff_states_offset = current_offset;
- current_offset += rnn.ws_diff_states_size;
-
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_grid_comp_offset = current_offset;
- current_offset += rnn.ws_grid_comp_size;
-
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_cell_comp_offset = current_offset;
- current_offset += rnn.ws_cell_comp_size;
-
- workspace_size = rnn.use_workspace ? current_offset : 0;
-
- /* Optional scratchpads */
- // Assumes the scratchpad base pointer is page aligned.
-    // If use_workspace, only the following goes to the scratchpad;
-    // otherwise everything goes to the scratchpad and we keep incrementing the offset
- current_offset = rnn.use_workspace ? 0 : current_offset;
-
- if (rnn.copy_bias) {
- current_offset = utils::rnd_up(current_offset, page_size);
- ws_bias_offset = current_offset;
- current_offset += rnn.ws_bias_size;
- }
-
- scratchpad_size = current_offset;
-}
-
-void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
- size_t &scratchpad_size, size_t &workspace_size) {
- size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
- ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
- ws_bias_offset;
- set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_diff_states_offset,
- ws_c_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
- ws_bias_offset, scratchpad_size, workspace_size);
-}
-
-status_t rnn_utils::set_good_strides(
- memory_desc_t &weights_md, format_tag_t tag) {
- auto &strides = weights_md.format_desc.blocking.strides;
- auto dims = weights_md.dims;
-
- if (tag == ldigo) {
- strides[2] = rnn_utils::get_good_ld((int)strides[2],
- (int)types::data_type_size(weights_md.data_type));
- strides[1] = dims[2] * strides[2];
- strides[0] = dims[1] * strides[1];
- } else if (tag == ldgoi) {
- strides[4] = rnn_utils::get_good_ld((int)strides[4],
- (int)types::data_type_size(weights_md.data_type));
- strides[3] = dims[4] * strides[4];
- strides[1] = dims[3] * strides[3];
- strides[0] = dims[1] * strides[1];
- } else
- return status::unimplemented;
-
- return status::success;
-}
-
-status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
- memory_desc_t &weights_md, bool is_iter) {
- using namespace format_tag;
- bool use_packed_gemm = is_iter
- ? rnn.use_iter_packed_gemm
- : rnn.use_layer_packed_gemm;
- if (use_packed_gemm) {
- weights_md.format_kind = format_kind::rnn_packed;
- rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc;
- rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p;
- if (is_iter) {
- rnn_pdata.n = rnn.mb;
- rnn_pdata.n_parts = rnn.n_parts_weights_iter;
- array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
- MKLDNN_RNN_MAX_N_PARTS);
- array_copy(rnn_pdata.part_pack_size,
- rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS);
- rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
- rnn_pdata.size = rnn.weights_iter_pack_size;
- } else {
- rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
- rnn_pdata.n_parts = rnn.n_parts_weights_layer;
- array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
- MKLDNN_RNN_MAX_N_PARTS);
- array_copy(rnn_pdata.part_pack_size,
- rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS);
- rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
- rnn_pdata.size = rnn.weights_layer_pack_size;
- }
- } else {
- CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi));
- // Adjust strides for good leading dimension in GEMM
- CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi));
- }
- return status::success;
-}
-
-}
-}
-}
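For a concrete feel of get_good_ld above: it pads a leading dimension up to a 64-byte boundary and then adds one more alignment step if the result is a multiple of 256 elements, per the aliasing note in the original comment. The standalone copy below adds a local rnd_up and two worked values; it is an illustration, not library code.

    // Standalone copy of rnn_utils::get_good_ld, for illustration only.
    static int rnd_up(int x, int multiple) { return ((x + multiple - 1) / multiple) * multiple; }

    static int get_good_ld(int dim, int sizeof_dt) {
        // Align the leading dimension to 64 bytes...
        int ld = rnd_up(dim, 64 / sizeof_dt);
        // ...but avoid results that are multiples of 256 elements.
        return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
    }

    // Examples for f32 (sizeof_dt = 4, so 16 elements per 64 bytes):
    //   dim = 500 -> rounds up to 512, a multiple of 256 -> returns 528
    //   dim = 130 -> rounds up to 144 -> returned unchanged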
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
deleted file mode 100644
index 99eb787a64..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
+++ /dev/null
@@ -1,225 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef RNN_UTILS_HPP
-#define RNN_UTILS_HPP
-
-#include "mkldnn.h"
-
-#include "cpu_rnn_pd.hpp"
-
-
-#define rnn_elemwise_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \
- src_data_t *states_t_l_, float *c_states_t_l_, \
- src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
- float *diff_states_t_l_, float *diff_states_t_lp1_, \
- float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \
- float *ws_cell_) const
-
-#define rnn_cell_execution_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \
- float *c_states_t_l_, float *diff_states_t_l_, \
- weights_data_t **w_layer_, weights_data_t **w_iter_, \
- float **bias_, src_data_t *states_t_lm1_, \
- src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
- float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
- float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \
- acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const
-
-#define rnn_grid_execution_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \
- weights_data_t **weights_states_, float **bias_, \
- src_data_t *ws_states_, float *ws_c_states_, \
- float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \
- float *ws_grid_, float *diff_weights_layer_, \
- float *diff_weights_iter_, float *diff_bias_) const
-
-#define rnn_gemm_sig(f) \
- void f(const char transA, const char transB, int m, int n, int k, \
- const float alpha, const weights_data_t *a_, const int ldA, \
- const src_data_t *b_, const int ldB, const float beta, \
- acc_data_t *c_, const int ldC) const
-
-#define rnn_bias_prepare_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
- float *scratch_bias_) const
-
-#define rnn_bias_finalize_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
- const float *w_iter_comp, const float *w_layer_comp) const
-
-#define rnn_weights_assign_sig(f) \
- void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \
- int ld, int OC_size, int IC_size, const int n_parts, \
- const int *gates_per_part, const size_t *part_weights_pack_size, \
- weights_data_t **weights_, const weights_data_t *w_, \
- float **bias_, const float *b_, float *scratch_bias_) const
-
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-namespace rnn_utils {
-
-using namespace mkldnn::impl::utils;
-
-enum execution_direction_t {
- l2r,
- r2l,
- bi_concat,
- bi_sum,
-};
-
-enum data_type_conf_t {
- all_f32,
- u8u8u8f32,
- f32u8f32f32,
- u8u8u8u8,
- f32u8f32u8
-};
-
-struct rnn_conf_t {
- execution_direction_t exec_dir;
- data_type_conf_t dt_conf;
- int n_layer, n_iter, n_dir, n_gates, n_states;
- int mb;
- int slc, sic, dic, dlc;
- int gates_ld, gates_nld, gates_ws_ld;
- int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS];
- int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS];
- int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS];
- size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS],
- part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS];
- bool weights_layer_is_packed, weights_iter_is_packed;
- /* Size of packed data in bytes */
- size_t weights_layer_comp_offset, weights_layer_pack_size,
- weights_iter_comp_offset, weights_iter_pack_size;
-
- bool copy_bias;
- int weights_layer_ld, weights_layer_nld;
- int diff_weights_layer_ld, diff_weights_layer_nld;
- int weights_iter_ld, weights_iter_nld;
- int diff_weights_iter_ld, diff_weights_iter_nld;
- int states_nld, states_ws_ld;
- int weights_iter_compensation_size, weights_layer_compensation_size;
- bool is_fwd, is_training, is_lbr;
- bool use_workspace;
-
- /* Size of workspace for each tensor in bytes */
- size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size,
- ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size;
- bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm,
- use_iter_packed_gemm;
-};
-
-bool is_ldigo(const memory_desc_wrapper &md);
-bool is_ldgoi(const memory_desc_wrapper &md);
-
-int get_good_ld(int dim, int sizeof_dt);
-
-void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
- const memory_desc_wrapper &src_layer_d,
- const memory_desc_wrapper &src_iter_d,
- const memory_desc_wrapper &weights_layer_d,
- const memory_desc_wrapper &weights_iter_d,
- const memory_desc_wrapper &dst_layer_d);
-
-void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
- const memory_desc_wrapper &weights_layer_d,
- const memory_desc_wrapper &weights_iter_d,
- const memory_desc_wrapper &diff_weights_layer_d,
- const memory_desc_wrapper &diff_weights_iter_d);
-
-void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
- size_t &ws_h_state_offset, size_t &ws_c_state_offset,
- size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
- size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
- size_t &scratchpad_size, size_t &workspace_size);
-
-void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
- size_t &scratchpad_size, size_t &workspace_size);
-status_t set_expected_desc(
- rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter);
-status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);
-
-template <typename T>
-struct ws_gates_aoc {
- ws_gates_aoc(const rnn_conf_t &rnn, T *data)
- : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {}
- T &operator()(int batch, int gate, int dic) {
- return gates_(batch, gate * DIC_ + dic);
- }
-
-private:
- mkldnn::impl::utils::array_offset_calculator<T, 2> gates_;
- int DIC_;
-};
-using ws_gates_aoc_t = ws_gates_aoc<float>;
-using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
-
-struct bias_aoc_t {
- bias_aoc_t(const rnn_conf_t &rnn, const float *data)
- : bias_(data, rnn.n_bias, rnn.dic) {}
- const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); }
-
-private:
- mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_;
-};
-
-template <typename T>
-struct ws_states_aoc {
- ws_states_aoc(const rnn_conf_t &rnn, T *data)
- : state_(data, rnn.states_nld, rnn.states_ws_ld) {}
- T &operator()(int batch, int dic) { return state_(batch, dic); }
-
-private:
- mkldnn::impl::utils::array_offset_calculator<T, 2> state_;
-};
-using ws_states_aoc_t = ws_states_aoc<float>;
-using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>;
-
-struct ws_diff_states_aoc_t {
- ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data)
- : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld,
- rnn.states_ws_ld) {}
- float &operator()(int state_n, int batch, int dic) {
- return diff_states_(state_n, 0, batch, dic);
- }
-
-private:
- mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_;
-};
-
-struct ws_diff_w_iter_aoc_t {
- ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
- : diff_weights_iter_(
- data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
- , DIC_(rnn.dic) {}
- float &operator()(int sic, int gate, int dic) {
- return diff_weights_iter_(sic, gate * DIC_ + dic);
- }
-
-private:
- mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_;
- int DIC_;
-};
-}
-}
-}
-}
-#endif
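
For context: the ws_gates_aoc / ws_states_aoc helpers deleted above are thin multi-dimensional views over flat RNN workspace buffers, indexed through leading dimensions taken from rnn_conf_t. A minimal standalone sketch of the same idea (hypothetical view_2d type and made-up sizes, not the library's array_offset_calculator):

    // A 2-D view over a flat buffer with an explicit leading dimension,
    // the pattern the ws_*_aoc helpers above are built on.
    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    template <typename T>
    struct view_2d {
        view_2d(T *data, std::ptrdiff_t rows, std::ptrdiff_t ld)
            : data_(data), rows_(rows), ld_(ld) {}
        // row-major access: element (r, c) lives at r * ld_ + c
        T &operator()(std::ptrdiff_t r, std::ptrdiff_t c) {
            assert(r >= 0 && r < rows_);
            return data_[r * ld_ + c];
        }
    private:
        T *data_;
        std::ptrdiff_t rows_, ld_;
    };

    int main() {
        // e.g. a gates workspace: mb = 2 rows, n_gates * dic = 8 columns,
        // stored with a padded leading dimension of 12
        std::vector<float> ws(2 * 12, 0.f);
        view_2d<float> gates(ws.data(), 2, 12);
        gates(1, 3) = 42.f;               // batch 1, gate/channel 3
        std::printf("%g\n", gates(1, 3)); // prints 42
        return 0;
    }

The padded leading dimension (here 12) is what lets the kernels keep rows aligned without changing any of the indexing code.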
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp
deleted file mode 100644
index 0420f87aa5..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp
+++ /dev/null
@@ -1,126 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_thread.hpp"
-
-#include "simple_concat.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace memory_tracking::names;
-
-template <data_type_t data_type>
-status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
- auto scratchpad = this->scratchpad(ctx);
- auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs);
- auto optrs = scratchpad.template get<data_t *>(key_concat_optrs);
- auto nelems_to_copy = scratchpad.template get<dim_t>(key_concat_nelems);
- auto is = scratchpad.template get<strides_t>(key_concat_istrides);
-
- const int num_arrs = pd()->n_inputs();
- const int *perm = pd()->perm_, *iperm = pd()->iperm_;
- const int concat_dim = pd()->concat_dim();
- auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- for (int a = 0; a < num_arrs; ++a) {
- const memory_desc_wrapper i_d(pd()->src_md(a));
- const memory_desc_wrapper o_d(pd()->src_image_md(a));
-
- iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a)
- + i_d.blk_off(0);
- optrs[a] = o_base_ptr + o_d.blk_off(0);
- nelems_to_copy[a] = pd()->nelems_to_concat(i_d);
- for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) {
- if (i < perm[concat_dim])
- is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]);
- else
- is[a][i] = 0;
- }
- }
-
- const memory_desc_wrapper o_d(pd()->src_image_md(0));
-
- strides_t os = { 0 };
- for (int i = 0; i < perm[concat_dim]; i++)
- os[i] = o_d.blocking_desc().strides[iperm[i]];
-
- dims_t phys_dims;
- for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++)
- phys_dims[i] = (i < (size_t)perm[concat_dim])
- ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1;
-
- if (perm[concat_dim] == 0) {
- for (int a = 0; a < num_arrs; ++a) {
- const data_t *i = &iptrs[a][0];
- data_t *o = &optrs[a][0];
- parallel_nd((ptrdiff_t)nelems_to_copy[a],
- [&](ptrdiff_t e) { o[e] = i[e]; });
- }
- } else {
- parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
- phys_dims[4], num_arrs,
- [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) {
- // XXX: this code may access uninitialized values in is[*][0-4] --
- // that's why we have to set them to zero although this is
- // probably benign
- size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2
- + is[a][3] * n3 + is[a][4] * n4;
- size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2
- + os[3] * n3 + os[4] * n4;
- const data_t *i = &iptrs[a][in_off];
- data_t *o = &optrs[a][out_off];
-#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
- // The code below performs data copying: o[e] = i[e]
- // and uses a workaround to make GNU compilers optimize it
- uint8_t *ptro = reinterpret_cast<uint8_t *>(o);
- const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i);
- const dim_t main_part =
- nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t);
- const dim_t tail_part =
- nelems_to_copy[a] % sizeof(data_t) / sizeof(uint32_t);
-
- PRAGMA_OMP_SIMD()
- for (dim_t e = 0; e < main_part; ++e) {
- *(reinterpret_cast<uint32_t *>(ptro))
- = *(reinterpret_cast<const uint32_t *>(ptri));
- ptro += sizeof(uint32_t);
- ptri += sizeof(uint32_t);
- }
- for (dim_t e = 0; e < tail_part; ++e) {
- *ptro = *ptri;
- ++ptro;
- ++ptri;
- }
-#else
- PRAGMA_OMP_SIMD()
- for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e];
-#endif
- });
- }
-
- return status::success;
-}
-
-template struct simple_concat_t<data_type::f32>;
-template struct simple_concat_t<data_type::u8>;
-template struct simple_concat_t<data_type::s8>;
-template struct simple_concat_t<data_type::s32>;
-
-}
-}
-}
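
For context: when the concat axis maps to position 0 of the sorted memory order, the deleted kernel above degenerates to one contiguous copy per input; otherwise it walks the outer dimensions and copies dense inner spans. A simplified standalone sketch of that inner-span copy for concatenation along a dense last axis (hypothetical concat_cols helper, not the primitive's API):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Concatenate 'srcs', each of shape rows x cols[a], along the column axis.
    std::vector<float> concat_cols(const std::vector<std::vector<float>> &srcs,
                                   const std::vector<std::size_t> &cols,
                                   std::size_t rows) {
        std::size_t total_cols = 0;
        for (std::size_t c : cols) total_cols += c;
        std::vector<float> dst(rows * total_cols);

        std::size_t col_off = 0;
        for (std::size_t a = 0; a < srcs.size(); ++a) {
            for (std::size_t r = 0; r < rows; ++r) {
                // one contiguous span of cols[a] elements per output row
                std::copy_n(srcs[a].data() + r * cols[a], cols[a],
                            dst.data() + r * total_cols + col_off);
            }
            col_off += cols[a];
        }
        return dst;
    }

Calling it with two sources of shape rows x cols[0] and rows x cols[1] yields a rows x (cols[0] + cols[1]) buffer, which is what the strided loops above compute element by element.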
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
deleted file mode 100644
index 057cc3c4c7..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
+++ /dev/null
@@ -1,155 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SIMPLE_CONCAT_HPP
-#define SIMPLE_CONCAT_HPP
-
-#include "memory_tracking.hpp"
-
-#include "cpu_concat_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t data_type>
-struct simple_concat_t: public cpu_primitive_t {
- struct pd_t: public cpu_concat_pd_t {
- using cpu_concat_pd_t::cpu_concat_pd_t;
-
- pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
- int ndims = rhs.dst_md_.ndims;
- utils::array_copy(perm_, rhs.perm_, ndims);
- utils::array_copy(iperm_, rhs.iperm_, ndims);
- utils::array_copy(blocks_, rhs.blocks_, ndims);
- }
-
- DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
-
- status_t init() {
- const memory_desc_wrapper dst_d(dst_md());
- bool ok = true
- && cpu_concat_pd_t::init() == status::success
- && dst_d.ndims() <= 6;
- if (!ok) return status::unimplemented;
-
- for (size_t i = 0; i < src_mds_.size(); ++i) {
- const memory_desc_wrapper i_d(&src_mds_[i]);
- const memory_desc_wrapper o_d(&src_image_mds_[i]);
-
- const int ignore_strides = 0;
-
- ok = ok
- && utils::everyone_is(data_type, i_d.data_type(),
- o_d.data_type())
- && utils::everyone_is(format_kind::blocked,
- i_d.format_kind(), o_d.format_kind())
- && types::blocking_desc_is_equal(i_d.blocking_desc(),
- o_d.blocking_desc(), ignore_strides)
- && types::blocking_desc_is_equal(i_d.blocking_desc(),
- dst_d.blocking_desc(), ignore_strides)
- && !i_d.is_additional_buffer();
- if (!ok) return status::unimplemented;
- }
-
- dst_d.compute_blocks(blocks_);
- format_perm();
-
- // start dim is the first dimension after which the concatenation
- // would happen contiguously
- const int start_dim = perm_[concat_dim()];
-
- // check that contiguous part is indeed contiguous (i.e. dense)
- if (nelems_to_concat(dst_d) !=
- dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
- * dst_d.blocking_desc().strides[concat_dim()])
- return status::unimplemented;
-
- // check that all inputs have the same strides for the
- // contiguous part [concat_dim .. ndims] for the *major* dims.
- // the block part is already checked above
- for (size_t i = 0; i < src_mds_.size(); ++i) {
- const memory_desc_wrapper i_d(&src_mds_[i]);
- for (int d = start_dim; d < dst_d.ndims(); ++d) {
- if (dst_d.blocking_desc().strides[iperm_[d]]
- != i_d.blocking_desc().strides[iperm_[d]])
- return status::unimplemented;
- }
- }
-
- init_scratchpad();
-
- return status::success;
- }
-
- int perm_[MKLDNN_MAX_NDIMS] {};
- int iperm_[MKLDNN_MAX_NDIMS] {};
- dims_t blocks_ {};
-
- dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
- const int ndims = data_d.ndims();
-
- dim_t nelems = 1;
- for (int i = perm_[concat_dim()]; i < ndims; i++)
- nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]];
- for (int i = 0; i < ndims; i++)
- nelems *= blocks_[i];
-
- return nelems;
- }
-
- private:
- void format_perm() {
- const memory_desc_wrapper dst_d(dst_md());
- const int ndims = dst_d.ndims();
-
- strides_t strides;
- utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
- for (int i = 0; i < ndims; i++) iperm_[i] = i;
-
- utils::simultaneous_sort(strides, iperm_, ndims,
- [](stride_t a, stride_t b) { return b - a; });
-
- for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
- }
-
- void init_scratchpad() {
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
- scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
- scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs());
- scratchpad.book(key_concat_istrides,
- sizeof(strides_t) * n_inputs());
- }
- };
-
- simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override;
-
- typedef typename prec_traits<data_type>::type data_t;
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
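
For context: format_perm() in the deleted header derives perm_ and iperm_ by sorting dimension indices by decreasing stride, so the kernel can reason in memory order rather than logical order. A standalone sketch of that trick with a fixed rank (simplified; the real code sorts a copy of the strides in place via simultaneous_sort):

    #include <algorithm>
    #include <array>
    #include <numeric>

    constexpr int NDIMS = 4;

    struct perms { std::array<int, NDIMS> perm, iperm; };

    perms format_perm(const std::array<long, NDIMS> &strides) {
        perms p{};
        std::iota(p.iperm.begin(), p.iperm.end(), 0);
        // iperm maps "position in memory order" -> logical dimension
        std::stable_sort(p.iperm.begin(), p.iperm.end(),
                [&](int a, int b) { return strides[a] > strides[b]; });
        // perm is the inverse permutation
        for (int i = 0; i < NDIMS; ++i) p.perm[p.iperm[i]] = i;
        return p;
    }
    // e.g. for logical dims {N,C,H,W} stored as NHWC, the strides are
    // {H*W*C, 1, W*C, C}; sorting gives iperm = {0,2,3,1} (memory order
    // N,H,W,C) and perm = {0,3,1,2}.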
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp
deleted file mode 100644
index e6c3b8d7af..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp
+++ /dev/null
@@ -1,98 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_SIMPLE_Q10N_HPP
-#define CPU_SIMPLE_Q10N_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "math_utils.hpp"
-#include "nstl.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::math;
-
-template <typename out_t>
-inline out_t round_and_saturate(float f)
-{ return math::saturate<out_t>(out_round<int>(f)); }
-
-/* Quantization with alpha == 1 and beta == 0 */
-template <typename in_t, typename out_t, typename enabled = void>
-struct qz_a1b0 {
- out_t operator()(in_t in)
- { return round_and_saturate<out_t>((float)in); }
-};
-
-template <typename in_t, typename out_t>
-struct qz_a1b0<in_t, out_t,
- typename utils::enable_if<true
- && nstl::is_integral<in_t>::value
- && !is_subset<in_t, out_t>::value
- >::type> {
- out_t operator()(in_t in) { return math::saturate<out_t>(in); }
-};
-
-template <typename in_t, typename out_t>
-struct qz_a1b0<in_t, out_t,
- typename utils::enable_if<is_subset<in_t, out_t>::value>::type> {
- out_t operator()(in_t in) { return (out_t)in; }
-};
-
-/* Quantization with alpha == 1 */
-template <typename in_t, typename out_t> struct qz_a1 {
- out_t operator()(in_t in, out_t out, float beta)
- { return round_and_saturate<out_t>((float)in + beta * out); }
-};
-
-template <typename in_t> struct qz_a1<in_t, float> {
- float operator()(in_t in, float out, float beta)
- { return (float)in + beta * out; }
-};
-
-/* Quantization with beta == 0 */
-template <typename in_t, typename out_t> struct qz_b0 {
- out_t operator()(in_t in, float alpha)
- { return round_and_saturate<out_t>(alpha * in); }
-};
-
-template <typename in_t> struct qz_b0<in_t, float> {
- float operator()(in_t in, float alpha) { return alpha * in; }
-};
-
-/* Quantization */
-template <typename in_t, typename out_t> struct qz {
- out_t operator()(in_t in, out_t out, float alpha, float beta) {
- return round_and_saturate<out_t>(
- alpha * in + (beta ? beta * out : 0));
- }
-};
-
-template <typename in_t> struct qz<in_t, float> {
- float operator()(in_t in, float out, float alpha, float beta)
- { return alpha * in + (beta ? beta * out : 0); }
-};
-
-}
-}
-}
-
-#endif
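
For context: all of the qz_* functors deleted above compute variations of one rule, out = saturate<out_t>(round(alpha * in + beta * out)), with the alpha == 1 and beta == 0 cases specialized so the unused multiplies disappear at compile time. A standalone sketch of the general rule (simplified rounding and clamping, not the library's out_round/saturate helpers):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    template <typename out_t>
    out_t round_and_saturate(float f) {
        const float lo = static_cast<float>(std::numeric_limits<out_t>::lowest());
        const float hi = static_cast<float>(std::numeric_limits<out_t>::max());
        // round to nearest, then clamp into the representable range of out_t
        return static_cast<out_t>(std::max(lo, std::min(hi, std::nearbyint(f))));
    }

    template <typename in_t, typename out_t>
    out_t qz(in_t in, out_t out, float alpha, float beta) {
        return round_and_saturate<out_t>(alpha * static_cast<float>(in)
                + (beta ? beta * static_cast<float>(out) : 0.f));
    }

    int main() {
        // f32 -> s8 with scale 0.5: 300.f maps to 150, saturated to 127
        std::printf("%d\n",
                static_cast<int>(qz<float, std::int8_t>(300.f, 0, 0.5f, 0.f)));
        return 0;
    }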
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp
deleted file mode 100644
index ff845f5bd3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp
+++ /dev/null
@@ -1,1022 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_SIMPLE_REORDER_HPP
-#define CPU_SIMPLE_REORDER_HPP
-
-#include <assert.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-#include "utils.hpp"
-
-#include "tag_traits.hpp"
-#include "cpu_reorder_pd.hpp"
-#include "cpu_primitive.hpp"
-
-#include "simple_q10n.hpp"
-#include "cpu_isa_traits.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::status;
-using namespace mkldnn::impl::format_tag;
-using namespace mkldnn::impl::data_type;
-
-using bd = block_dim_t;
-using ib = inner_blk_t;
-
-using namespace mkldnn::impl::utils;
-using math::saturate;
-
-template<impl::data_type_t type>
-using data_t = typename prec_traits<type>::type;
-
-template<impl::data_type_t type_i, impl::data_type_t type_o>
-using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
-
-template<impl::data_type_t type_i, impl::data_type_t type_o>
-using _qz = qz<data_t<type_i>, data_t<type_o>>;
-
-namespace fmt_order {
- const bool keep = true;
- const bool reverse = false;
- const bool any = keep;
-}
-
-namespace spec {
-struct direct_copy {};
-struct direct_copy_except_dim_0 {};
-struct reference {};
-struct conv_s8s8 {};
-}
-
-#define SIMPLE_REORDER_TEMPL_DECL \
- impl::data_type_t type_i, impl::format_tag_t tag_i, \
- impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep
-#define SIMPLE_REORDER_TEMPL_CALL \
- type_i, tag_i, type_o, tag_o, order_keep
-
-#define DECLARE_COMMON_PARAMS() \
- const memory_desc_wrapper &input_d = pd->src_md(); \
- const memory_desc_wrapper &output_d = pd->dst_md(); \
- const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
- const float beta = pd->beta(); MAYBE_UNUSED(beta);
-
-/* specific reorders: common template */
-template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
-struct simple_reorder_impl {};
-
-namespace {
-inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i,
- impl::format_tag_t tag_o, const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d) {
- return input_d.matches_tag(order_keep ? tag_i : tag_o)
- && output_d.matches_tag(order_keep ? tag_o : tag_i);
-}
-inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
- if (many_scales_support)
- return true;
- return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
-}
-}
-
-/* specific reorders: implementation */
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
-typename utils::enable_if<tag_i == any && (false
- || tag_o == hwio
- || tag_o == hwigo)
- , spec::conv_s8s8>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
- {
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(attr->output_scales_.mask_ + 1));
- const int oc = (input_d.dims()[tag_o == hwigo + 0]);
- const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1;
-
- return output_d.matches_tag(tag_o)
- && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
- && (input_d.data_type() == f32 || input_d.data_type() == s8)
- && output_d.data_type() == s8
- && (D_mask == 1 || D_mask == (size_t)g * oc);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- static constexpr bool w_groups = tag_o == hwigo;
-
- const auto &dims = input_d.dims();
- const auto &pdims = output_d.padded_dims();
-
- const int G = w_groups ? dims[0] : 1;
- const int OC = dims[w_groups + 0];
- const int IC = dims[w_groups + 1];
- const int H = dims[w_groups + 2];
- const int W = dims[w_groups + 3];
-
- const float *scales = pd->attr()->output_scales_.scales_;
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
-
- assert(output_d.extra().flags
- & memory_extra_flags::compensation_conv_s8s8);
- float adj_scale =
- (output_d.extra().flags & memory_extra_flags::scale_adjust)
- ? output_d.extra().scale_adjust : 1.f;
-
- size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
- int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
-
- parallel_nd(G, OC, [&](int g, int oc) {
- cp[g * OC + oc] = 0;
- for (int ic = 0; ic < IC; ic++)
- for (int h = 0; h < H; h++)
- for (int w = 0; w < W; w++) {
- auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
- auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
- const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
-
- o = qz_b0<data_t<type_i>, data_t<type_o>>()(
- i, s * adj_scale);
- cp[g * OC + oc] -= (int32_t)o;
- }
- cp [g * OC + oc] *= 128;
- });
- return success;
- }
-};
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
- typename utils::enable_if<
- (tag_i == goiw && tag_o == gOIw4i16o4i)
- || (tag_i == oiw && tag_o == OIw4i16o4i)
- || (tag_i == goihw && tag_o == gOIhw4i16o4i)
- || (tag_i == oihw && tag_o == OIhw4i16o4i)
- || (tag_i == goihw && tag_o == gOIhw2i8o4i)
- || (tag_i == goihw && tag_o == gOIhw4o4i)
- , spec::conv_s8s8>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
- {
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(attr->output_scales_.mask_ + 1));
- const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
- const int oc = (input_d.dims()[w_groups ? 1 : 0]);
- const int g = w_groups ? input_d.dims()[0] : 1;
-
- return input_d.matches_tag(tag_i)
- && output_d.matches_tag(tag_o)
- && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
- && (input_d.data_type() == f32 || input_d.data_type() == s8)
- && output_d.data_type() == s8
- && (D_mask == 1 || D_mask == (size_t)g * oc);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- static constexpr bool w_groups =
- !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
- constexpr int is_1d =
- utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i);
- constexpr int blksize = tag_traits<tag_o>::inner_blks == ib::_4b4c
- ? 4
- : tag_traits<tag_o>::inner_blks == ib::_2c8b4c
- ? 8
- : 16;
-
- const auto &_g_oihw_d = order_keep ? input_d : output_d;
- const auto &dims = input_d.dims();
- const auto &pdims = order_keep
- ? output_d.padded_dims()
- : input_d.padded_dims();
-
- const int G = w_groups ? dims[0] : 1;
- const int OC = dims[w_groups + 0];
- const int NB_OC = pdims[w_groups + 0] / blksize;
- const int IC = dims[w_groups + 1];
- const int NB_IC = pdims[w_groups + 1] / blksize;
- const int H = is_1d ? 1 : dims[w_groups + 2];
- const int W = dims[w_groups + 3 - is_1d];
-
- const float *scales = pd->attr()->output_scales_.scales_;
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
-
- assert(output_d.extra().flags
- & memory_extra_flags::compensation_conv_s8s8);
- float adj_scale =
- (output_d.extra().flags & memory_extra_flags::scale_adjust)
- ? output_d.extra().scale_adjust : 1.f;
-
- auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
- int32_t *c, const float *s, const int oc_block, const int ic_block) {
-# define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
-
- for (int ic = 0; ic < ic_block; ++ic) {
- for (int oc = 0; oc < oc_block; ++oc) {
- const auto _g_oihw_off =
- oc * _g_oihw_d.blocking_desc().strides[w_groups + 0]
- + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1];
- out[index(oc, ic)]
- = qz_b0<data_t<type_i>, data_t<type_o>>()(
- inp[_g_oihw_off], s[oc] * adj_scale);
- c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
- }
- }
-# undef index
- };
-
- constexpr int i_mult = blksize;
- constexpr int o_mult = 1;
-
- size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
- int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
- parallel_nd(G * NB_OC * blksize, [&](int i) {
- cp[i] = 0;
- });
-
-# define wei_blk_off(md, g, o, i, h, w) \
- (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
- : (md).blk_off<!w_groups>(g, o, i, h, w))
-
- parallel_nd(G, NB_OC, [&](int g, int O) {
- for (int I = 0; I < NB_IC; I++)
- for (int h = 0; h < H; h++)
- for (int w = 0; w < W; w++) {
- auto i = &input[wei_blk_off(
- input_d, g, i_mult * O, i_mult * I, h, w)];
- auto o = &output[wei_blk_off(
- output_d, g, o_mult * O, o_mult * I, h, w)];
- const int oc_block = nstl::min(blksize, OC - O * blksize);
- const int ic_block = nstl::min(blksize, IC - I * blksize);
-
- int _offset = (g * NB_OC + O) * blksize;
- ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
- &scales[(D_mask == 1) ? 0 : _offset],
- oc_block, ic_block);
- }
- });
-
-# undef wei_blk_off
-
- return success;
- }
-};
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
- typename utils::enable_if<false
- ||(tag_i == goiw && tag_o == Goiw16g)
- ||(tag_i == goihw && tag_o == Goihw16g)
- , spec::conv_s8s8>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(attr->output_scales_.mask_ + 1));
- const int oc = input_d.dims()[1];
- const int g = input_d.dims()[0];
-
- return true
- && order_keep
- && input_d.matches_tag(tag_i)
- && output_d.matches_tag(tag_o)
- && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
- && (input_d.data_type() == f32 || input_d.data_type() == s8)
- && output_d.data_type() == s8
- && (D_mask == 1 || D_mask == (size_t)g * oc);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- constexpr bool is_1d = tag_i == goiw;
- constexpr int blksize = 16;
-
- const auto &dims = input_d.dims();
- const auto &pdims = output_d.padded_dims();
- const int G = dims[0];
- const int Gp = pdims[0];
- const int OC = dims[1];
- const int IC = dims[2];
- const int H = is_1d ? 1 : dims[3];
- const int W = dims[4 - is_1d];
-
- const size_t D_mask = utils::array_product(input_d.dims(),
- math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
- const float *scales = pd->attr()->output_scales_.scales_;
-
- assert(output_d.extra().flags
- & memory_extra_flags::compensation_conv_s8s8);
- float adj_scale =
- (output_d.extra().flags & memory_extra_flags::scale_adjust)
- ? output_d.extra().scale_adjust : 1.f;
-
- auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
- int32_t *cp, const float *s, const int g_block) {
- PRAGMA_OMP_SIMD()
- for (int g = 0; g < g_block; g++) {
- const auto i_off = g * input_d.blocking_desc().strides[0];
- out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
- inp[i_off], s[g * OC] * adj_scale);
- cp[g * OC] -= 128 * (int32_t)(out[g]);
- }
- };
-
- size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
- int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
- parallel_nd((Gp/blksize) * OC, [&](int ib) {
- PRAGMA_OMP_SIMD()
- for (int i = 0; i < blksize; i++)
- cp[ib * blksize + i] = 0;
- });
-
-# define wei_blk_off(md, g, o, i, h, w) \
- (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))
-
- parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
- for (int I = 0; I < IC; I++) {
- for (int h = 0; h < H; h++)
- for (int w = 0; w < W; w++)
- {
- const int g_block = nstl::min(G - gb * blksize, blksize);
- const auto inp = &input[wei_blk_off(
- input_d, gb * blksize, O, I, h, w)];
- const auto out = &output[wei_blk_off(
- output_d, gb, O, I, h, w)];
- int offset = gb * blksize + O;
- ker(inp, out, &cp[offset],
- &scales[(D_mask == 1) ? 0 : offset], g_block);
- }
- }
- });
-
-# undef wei_blk_off
-
- return success;
- }
-};
-
-/* reorders with tail support */
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
-typename utils::enable_if<false
- || (tag_i == nCdhw8c && tag_o == nCdhw16c)
- || (tag_i == nChw8c && tag_o == nChw16c)
- || (tag_i == nCw8c && tag_o == nCw16c)
- >::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
- {
- return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d)
- && simple_attr_check(attr, false);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- constexpr int is_1d = tag_i == nCw8c;
- constexpr int is_3d = tag_i == nCdhw8c;
- constexpr int blksize_16 = 16;
- constexpr int blksize_8 = 8;
- constexpr int ic_mult = order_keep ? 2 : 1;
- constexpr int oc_mult = order_keep ? 1 : 2;
-
- const auto &dims = input_d.dims();
- const auto &pdims = order_keep ? output_d.padded_dims()
- : input_d.padded_dims();
-
- const int C = dims[1];
- const int D = is_3d ? dims[2] : 1;
- const int H = is_1d ? 1 : dims[2 + is_3d];
- const int W = dims[3 + is_3d - is_1d];
-
- auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
- const int block_16) {
- const int nb = (block_16 - 1) / blksize_8 + 1;
- if (alpha == 1.0 && beta == 0.0) {
- for (int b = 0; b < nb; ++b) {
- const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
- const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
- const int block_8 = nstl::min(blksize_8,
- block_16 - b * blksize_8);
- for (int c = 0; c < block_8; ++c) {
- o[o_off + c] = _qz_a1b0<type_i, type_o>()(
- i[i_off + c]);
- }
- }
- } else {
- for (int b = 0; b < nb; ++b) {
- const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
- const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
- const int block_8 = nstl::min(blksize_8,
- block_16 - b * blksize_8);
- for (int c = 0; c < block_8; ++c) {
- o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
- o[o_off + c], alpha, beta);
- }
- }
- }
- };
-
-# define data_blk_off(md, n, c, d, h, w) \
- ( is_1d ? (md).blk_off(n, c, w) \
- : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
-
- parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
- [&](int n, int nb_c, int d, int h, int w) {
- auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
- auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
- const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
- ker(i, o, block_16);
- });
-
-# undef data_blk_off
-
- return success;
- }
-};
-
-#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
- static bool is_applicable(const memory_desc_wrapper &input_d, \
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
- return simple_attr_check(attr, false) && (order_keep \
- ? output_d.matches_tag(tag_o) && input_d.is_plain() \
- : input_d.matches_tag(tag_o) && output_d.is_plain()); \
- }
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
-typename utils::enable_if<tag_i == any
- && (tag_traits<tag_o>::block_dims == bd::_A
- || tag_traits<tag_o>::block_dims == bd::_B)
- && tag_traits<tag_o>::ndims >= 3
- && tag_traits<tag_o>::ndims <= 6
- >::type>
-{
- PLAIN_TO_BLOCKED_IS_APPLICABLE();
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- const auto &flat_d = order_keep ? input_d : output_d;
- const auto &block_d = order_keep ? output_d : input_d;
- const auto &dims = input_d.dims();
- const auto &pdims = block_d.padded_dims();
-
- constexpr int ndims = tag_traits<tag_o>::ndims;
- constexpr int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;
-
- const dim_t H0 = dims[0];
- const dim_t H1 = dims[1];
- const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1;
- const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
- const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
- const dim_t L = dims[ndims - 1];
- const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
-
- constexpr int blksize = false ? 0
- : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_4a, ib::_4b) ? 4
- : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_8a, ib::_8b) ? 8
- : 16;
-
- auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
- if (alpha == 1.0 && beta == 0.0) {
- for (int l = 0; l < L; ++l)
- for (int blk = 0; blk < block; ++blk) {
- const dim_t flat_off = 0
- + blk * flat_d.blocking_desc().strides[blk_idx]
- + l * flat_d.blocking_desc().strides[ndims - 1];
- if (order_keep) {
- o[l * l_blk_stride + blk] = _qz_a1b0<type_i, type_o>()(
- i[flat_off]);
- } else {
- o[flat_off] = _qz_a1b0<type_i, type_o>()(
- i[l * l_blk_stride + blk]);
- }
- }
- } else {
- for (int l = 0; l < L; ++l)
- for (int blk = 0; blk < block; ++blk) {
- const dim_t flat_off = 0
- + blk * flat_d.blocking_desc().strides[blk_idx]
- + l * flat_d.blocking_desc().strides[ndims - 1];
- if (order_keep) {
- o[l * l_blk_stride + blk] = _qz<type_i, type_o>()(
- i[flat_off], o[l * blksize + blk],
- alpha, beta);
- } else {
- o[flat_off] = _qz<type_i, type_o>()(
- i[l * l_blk_stride + blk], o[flat_off],
- alpha, beta);
- }
- }
- }
- };
-
-# define off(md, h0, h1, m0, m1, m2) \
- (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
- : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
- : ndims >= 4 ? (md).blk_off(h0, h1, m2) \
- : /* ndims >= 3 ? */ (md).blk_off(h0, h1))
-
- constexpr int i_mult = order_keep ? blksize : 1;
- constexpr int o_mult = order_keep ? 1 : blksize;
-
- if (blk_idx == 0) {
- const dim_t BH0 = pdims[0] / blksize;
- parallel_nd(BH0, H1, M0, M1, M2,
- [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
- auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)];
- auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)];
- const int block = nstl::min<int>(blksize, H0 - bh0 * blksize);
- ker(i, o, block);
- });
- } else if (blk_idx == 1) {
- const dim_t BH1 = pdims[1] / blksize;
- parallel_nd(H0, BH1, M0, M1, M2,
- [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
- auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)];
- auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)];
- const int block = nstl::min<int>(blksize, H1 - bh1 * blksize);
- ker(i, o, block);
- });
- } else {
- assert(!"unimplemented");
- }
-
-# undef off
-
- return success;
- }
-};
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
-typename utils::enable_if<tag_i == any
- && (tag_traits<tag_o>::block_dims == bd::_AB
- || tag_traits<tag_o>::block_dims == bd::_BC)
- && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
- tag_traits<tag_o>::ndims >= 3 && tag_traits<tag_o>::ndims <= 5)
- && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
- tag_traits<tag_o>::ndims >= 4 && tag_traits<tag_o>::ndims <= 6)
- >::type>
-{
- PLAIN_TO_BLOCKED_IS_APPLICABLE();
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- const auto &flat_d = order_keep ? input_d : output_d;
- const auto &dims = input_d.dims();
- const auto &pdims = order_keep
- ? output_d.padded_dims()
- : input_d.padded_dims();
-
- constexpr int ndims = tag_traits<tag_o>::ndims;
-
- static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
- const dim_t G = with_g ? dims[0] : 1;
-
- const dim_t H0 = dims[0 + with_g];
- const dim_t H1 = dims[1 + with_g];
-
- const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
- const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
- const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;
-
- constexpr int blksize_0 = false ? 0
- : utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_4b4a, ib::_4b4c, ib::_4c4b)
- ? 4
- : utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
- ? 8
- : utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c,
- ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b,
- ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c)
- ? 16 : INT_MIN;
-
- constexpr int blksize_1 = utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
- ? 8
- : utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b,
- ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
- ib::_4c16b4c, ib::_8c16b2c)
- ? 16
- : utils::one_of(tag_traits<tag_o>::inner_blks,
- ib::_4b4a, ib::_4b4c, ib::_4c4b,
- ib::_16a4b, ib::_16b4c)
- ? 4
- : INT_MIN;
-
- const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
- const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;
-
- auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
- const int block_h0, const int block_h1) {
-# define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
-
- if (alpha == 1.0 && beta == 0.0) {
- for (int h0 = 0; h0 < block_h0; ++h0)
- for (int h1 = 0; h1 < block_h1; ++h1) {
- const dim_t flat_off = 0
- + h0 * flat_d.blocking_desc().strides[with_g + 0]
- + h1 * flat_d.blocking_desc().strides[with_g + 1];
- if (order_keep) {
- o[blk_off(h0, h1)] = _qz_a1b0<type_i, type_o>()(
- i[flat_off]);
- } else {
- o[flat_off] = _qz_a1b0<type_i, type_o>()(
- i[blk_off(h0, h1)]);
- }
- }
- } else {
- for (int h0 = 0; h0 < block_h0; ++h0)
- for (int h1 = 0; h1 < block_h1; ++h1) {
- const dim_t flat_off = 0
- + h0 * flat_d.blocking_desc().strides[with_g + 0]
- + h1 * flat_d.blocking_desc().strides[with_g + 1];
- if (order_keep) {
- o[blk_off(h0, h1)] = _qz<type_i, type_o>()(i[flat_off],
- o[blk_off(h0, h1)], alpha, beta);
- } else {
- o[flat_off] = _qz<type_i, type_o>()(i[blk_off(h0, h1)],
- o[flat_off], alpha, beta);
- }
- }
- }
-
-# undef blk_off
- };
-
- constexpr int i_mult_0 = order_keep ? blksize_0 : 1;
- constexpr int o_mult_0 = order_keep ? 1 : blksize_0;
-
- constexpr int i_mult_1 = order_keep ? blksize_1 : 1;
- constexpr int o_mult_1 = order_keep ? 1 : blksize_1;
-
-# define off(md, g, h0, h1, m0, m1, m2) \
- (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
- : ndims >= 4 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
- : /* ndims >= 3 + with_g ? */ (md).blk_off<!with_g>(g, h0, h1, m2))
-
- parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
- [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) {
- auto i = &input[off(input_d,
- g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)];
- auto o = &output[off(output_d,
- g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)];
- const int block_h0 = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
- const int block_h1 = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
- ker(i, o, block_h0, block_h1);
- });
-
-# undef off
-
- return success;
- }
-};
-
-/* generic and direct-copy reorders */
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
- typename utils::enable_if<
- tag_i == any && tag_o == any && order_keep == fmt_order::any,
- spec::direct_copy>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
- /* FIXME: is the formula correct? */
- return input_d.similar_to(output_d, true, false, 0)
- && input_d.is_dense() && output_d.is_dense()
- && simple_attr_check(attr, false);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- assert(input_d.is_dense());
-
- input += input_d.blk_off(0);
- output += output_d.blk_off(0);
-
- const size_t nelems = input_d.nelems();
-
- constexpr int block_size = 16;
- const auto num_blocks = nelems / block_size;
- const auto rem_elems = nelems % block_size;
-
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- balance211(num_blocks, nthr, ithr, start, end);
- start = start * block_size;
- end = end * block_size;
-
- if (alpha == 1.0 && beta == 0.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start; e < end; ++e) {
- output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
- (input[e]);
- }
- } else if (alpha == 1.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start; e < end; ++e) {
- output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
- (input[e], output[e], beta);
- }
- } else if (beta == 0.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start; e < end; ++e) {
- output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
- (input[e], alpha);
- }
- } else {
- PRAGMA_OMP_SIMD()
- for (size_t e = start; e < end; ++e) {
- output[e] = qz<data_t<type_i>, data_t<type_o>>()
- (input[e], output[e], alpha, beta);
- }
- }
-
- if (rem_elems != 0 && ithr == nthr - 1){
- if (alpha == 1.0 && beta == 0.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = nelems - rem_elems; e < nelems; ++e) {
- output[e] = qz_a1b0<data_t<type_i>,
- data_t<type_o>>()(input[e]);
- }
- } else if (alpha == 1.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = nelems - rem_elems; e < nelems; ++e) {
- output[e] = qz_a1<data_t<type_i>,
- data_t<type_o>>()(input[e], output[e], beta);
- }
- } else if (beta == 0.0) {
- PRAGMA_OMP_SIMD()
- for (size_t e = nelems - rem_elems; e < nelems; ++e) {
- output[e] = qz_b0<data_t<type_i>,
- data_t<type_o>>()(input[e], alpha);
- }
- } else {
- PRAGMA_OMP_SIMD()
- for (size_t e = nelems - rem_elems; e < nelems; ++e) {
- output[e] = qz<data_t<type_i>, data_t<type_o>>()
- (input[e], output[e], alpha, beta);
- }
- }
- }
- });
- return success;
- }
-};
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
- typename utils::enable_if<
- tag_i == any && tag_o == any && order_keep == fmt_order::any,
- spec::direct_copy_except_dim_0>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
- auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
- return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
- };
- /* FIXME: is the formula correct? */
- return input_d.similar_to(output_d, true, false, 1)
- && is_dense_no_0(input_d) && is_dense_no_0(output_d)
- && simple_attr_check(attr, false);
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- input += input_d.blk_off(0);
- output += output_d.blk_off(0);
-
- const int N = input_d.dims()[0];
- const dim_t is = input_d.blocking_desc().strides[0];
- const dim_t os = output_d.blocking_desc().strides[0];
- const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
- const dim_t work_amount = N * nelems_no_d0;
-
- if (alpha == 1.0 && beta == 0.0) {
- parallel(0, [&](const int ithr, const int nthr) {
- dim_t n{0}, dim1_s{0};
- dim_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
- while(start < end) {
- dim_t work_rem = end - start;
- dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
- ? nelems_no_d0 : dim1_s + work_rem;
- PRAGMA_OMP_SIMD()
- for (dim_t e = dim1_s; e < dim1_e; ++e) {
- output[os * n + e] = _qz_a1b0<type_i, type_o>()(
- input[is * n + e]);
- }
- nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
- }
- });
- } else {
- parallel(0, [&](const int ithr, const int nthr) {
- dim_t n{0}, dim1_s{0};
- dim_t start{0}, end{0};
- balance211(work_amount, nthr, ithr, start, end);
- nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
- while(start < end) {
- dim_t work_rem = end - start;
- dim_t dim1_e =
- dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
- : dim1_s + work_rem;
- PRAGMA_OMP_SIMD()
- for (dim_t e = dim1_s; e < dim1_e; ++e){
- output[os * n + e] = _qz<type_i, type_o>()(
- input[is * n + e], output[os * n + e], alpha,
- beta);
- }
- nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
- }
- });
- }
-
- return success;
- }
-
-private:
- static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
- const int ndims = data_d.ndims();
- if (ndims <= 1) return 1;
- return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
- }
-
- static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
- dims_t blocks;
- data_d.compute_blocks(blocks);
-
- const auto &blk = data_d.blocking_desc();
-
- dim_t blk_size = 1;
- for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
- blk_size *= blk.inner_blks[iblk];
-
- dim_t max_size = blk_size;
- for (int d = 1; d < data_d.ndims(); ++d) {
- max_size = nstl::max(max_size,
- data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
- }
-
- return max_size;
- }
-};
-
-template <SIMPLE_REORDER_TEMPL_DECL>
-struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
- typename utils::enable_if<
- tag_i == any && tag_o == any && order_keep == fmt_order::any,
- spec::reference>::type>
-{
- static bool is_applicable(const memory_desc_wrapper &input_d,
- const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
- /* supported smask: 0x0...011..10...0,
- * i.e. 1 should be contiguous */
- int smask = attr ? attr->output_scales_.mask_ : 0;
- for (; smask > 0 && !(smask & 0x1); smask >>= 1);
- for (; smask > 0 && smask & 0x1; smask >>= 1);
- return true
- && input_d.is_blocking_desc()
- && output_d.is_blocking_desc()
- && !output_d.is_additional_buffer()
- && !input_d.is_additional_buffer()
- && smask == 0;
- }
-
- static status_t execute(const cpu_reorder_pd_t *pd,
- const data_t<type_i> *input, data_t<type_o> *output) {
- DECLARE_COMMON_PARAMS();
-
- const size_t nelems = input_d.nelems();
-
- int ndims_start = 0, ndims_mask = 0;
- int smask = pd->attr()->output_scales_.mask_;
- for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
- for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
- assert(smask == 0);
-
- const ptrdiff_t D_start
- = utils::array_product(input_d.dims(), ndims_start);
- const ptrdiff_t D_mask
- = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
- const ptrdiff_t D_rest = nelems / D_start / D_mask;
-
- const float *scales = pd->attr()->output_scales_.scales_;
-
- parallel_nd(D_start, D_mask, D_rest,
- [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
- const float scale = scales[dm];
-
- const size_t e = (ds * D_mask + dm) * D_rest + dr;
- const auto &i = input[input_d.off_l(e)];
- auto &o = output[output_d.off_l(e)];
-
- o = _qz<type_i, type_o>()(i, o, scale, beta);
- });
-
- return success;
- }
-};
-
-
-/* high level class declaration */
-
-template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
-struct simple_reorder_t: public cpu_primitive_t {
- struct pd_t: public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
- bool args_ok = true
- && src_md->data_type == type_i
- && dst_md->data_type == type_o
- && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
- is_applicable(src_md, dst_md, attr);
- if (!args_ok)
- return status::invalid_arguments;
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return status::out_of_memory;
- if (_pd->init() != status::success) {
- delete _pd;
- return status::unimplemented;
- }
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
- };
-
- simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto input = CTX_IN_MEM(const data_t<type_i> *, MKLDNN_ARG_FROM);
- auto output = CTX_OUT_MEM(data_t<type_o> *, MKLDNN_ARG_TO);
- simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
- pd(), input, output);
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-#undef SIMPLE_REORDER_TEMPL_DECL
-#undef SIMPLE_REORDER_TEMPL_CALL
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
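
For context: the bulk of the deleted simple_reorder.hpp is plain-to-blocked layout conversion. A naive standalone sketch of the nchw to nChw16c case shows what every specialized template above ultimately computes (sequential loops with zero-filled channel padding; the real code parallelizes, handles alpha/beta scaling, and covers many more tags):

    #include <cstddef>

    void nchw_to_nChw16c(const float *src, float *dst,
                         std::size_t N, std::size_t C,
                         std::size_t H, std::size_t W) {
        const std::size_t blk = 16;
        const std::size_t Cb = (C + blk - 1) / blk; // number of channel blocks
        for (std::size_t n = 0; n < N; ++n)
        for (std::size_t cb = 0; cb < Cb; ++cb)
        for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w)
        for (std::size_t c = 0; c < blk; ++c) {
            const std::size_t ch = cb * blk + c;
            // destination layout: n, Cb, H, W, 16 with the 16 channels innermost
            const std::size_t d_off = (((n * Cb + cb) * H + h) * W + w) * blk + c;
            // source layout: plain nchw; zero-fill the padded tail channels
            dst[d_off] = (ch < C)
                    ? src[((n * C + ch) * H + h) * W + w]
                    : 0.f;
        }
    }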
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp
deleted file mode 100644
index f0947573a9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp
+++ /dev/null
@@ -1,91 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include "mkldnn_thread.hpp"
-
-#include "simple_sum.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t data_type>
-status_t simple_sum_t<data_type>::execute(const exec_ctx_t &ctx) const {
- auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- const memory_desc_wrapper o_d(pd()->dst_md());
- output += o_d.blk_off(0);
-
- const int num_arrs = pd()->n_inputs();
- const data_t *input_ptrs[max_num_arrs];
- const size_t nelems = o_d.nelems();
-
- for (int a = 0; a < num_arrs; ++a) {
- const memory_desc_wrapper i_d(pd()->src_md(a));
- input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a)
- + i_d.blk_off(0);
- }
-
- const size_t block_size = 16 * 1024 / sizeof(data_type);
- const size_t blocks_number = nelems / block_size;
- const size_t tail = nelems % block_size;
-
- const auto scales = pd()->scales();
- parallel(0, [&](const int ithr, const int nthr) {
- size_t start{0}, end{0};
- balance211(blocks_number, nthr, ithr, start, end);
-
- for (size_t nb = start; nb < end; ++nb) {
- size_t start_e = nb * block_size;
- size_t end_e = start_e + block_size;
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = data_t(scales[0] * input_ptrs[0][e]);
- }
- for (int a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += data_t(scales[a] * input_ptrs[a][e]);
- }
- }
- }
-
- if (tail != 0 && ithr == nthr - 1) {
- size_t start_e = nelems - tail;
- size_t end_e = nelems;
-
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] = data_t(scales[0] * input_ptrs[0][e]);
- }
- for (int a = 1; a < num_arrs; a++) {
- PRAGMA_OMP_SIMD()
- for (size_t e = start_e; e < end_e; e++) {
- output[e] += data_t(scales[a] * input_ptrs[a][e]);
- }
- }
- }
- });
-
- return status::success;
-}
-
-template struct simple_sum_t<data_type::f32>;
-
-}
-}
-}
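
For context: the deleted simple_sum.cpp computes dst[e] = sum over a of scales[a] * src[a][e], split into roughly cache-sized blocks that are then distributed across threads. A sequential standalone sketch of the same blocking (hypothetical scaled_sum helper, not the primitive's API):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    void scaled_sum(const std::vector<const float *> &srcs,
                    const std::vector<float> &scales,
                    float *dst, std::size_t nelems) {
        const std::size_t block = 16 * 1024 / sizeof(float); // ~16 KiB chunks
        for (std::size_t start = 0; start < nelems; start += block) {
            const std::size_t end = std::min(nelems, start + block);
            // initialize the block from the first source, then accumulate
            for (std::size_t e = start; e < end; ++e)
                dst[e] = scales[0] * srcs[0][e];
            for (std::size_t a = 1; a < srcs.size(); ++a)
                for (std::size_t e = start; e < end; ++e)
                    dst[e] += scales[a] * srcs[a][e];
        }
    }

Processing a whole block per source before moving on keeps both the destination block and one source stream hot in cache, which is the point of the block_size / balance211 split in the code above.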
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp
deleted file mode 100644
index 2a0187a184..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp
+++ /dev/null
@@ -1,74 +0,0 @@
-/*******************************************************************************
-* Copyright 2017-2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef SIMPLE_SUM_HPP
-#define SIMPLE_SUM_HPP
-
-#include "cpu_sum_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t data_type>
-struct simple_sum_t: public cpu_primitive_t {
- struct pd_t: public cpu_sum_pd_t {
- using cpu_sum_pd_t::cpu_sum_pd_t;
-
- DECLARE_SUM_PD_T("simple:any", simple_sum_t);
-
- status_t init() {
- const int n = n_inputs();
-
- bool ok = true
- && cpu_sum_pd_t::init() == status::success
- && n <= max_num_arrs;
- if (!ok) return status::unimplemented;
-
- const memory_desc_wrapper o_d(dst_md());
- ok = ok
- && o_d.data_type() == data_type
- && o_d.is_dense();
- if (!ok) return status::unimplemented;
-
- for (int i = 0; i < n; ++i) {
- const memory_desc_wrapper i_d(src_md(i));
- if (i_d != o_d) return status::unimplemented;
- }
-
- return status::success;
- }
- };
-
- simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {}
-
- virtual status_t execute(const exec_ctx_t &ctx) const override;
-
- enum {max_num_arrs = 16 };
- typedef typename prec_traits<data_type>::type data_t;
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp
deleted file mode 100644
index c2082d7d62..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp
+++ /dev/null
@@ -1,376 +0,0 @@
-/*******************************************************************************
- * Copyright 2017-2018 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#ifndef CPU_WINO_REORDER_HPP
-#define CPU_WINO_REORDER_HPP
-
-#include "mkldnn_thread.hpp"
-
-#include "simple_q10n.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-template <data_type_t type_i, data_type_t type_o>
-struct wino_reorder_t : public cpu_primitive_t {
- struct pd_t : public cpu_reorder_pd_t {
- using cpu_reorder_pd_t::cpu_reorder_pd_t;
-
- DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t);
-
- static status_t create(reorder_pd_t **reorder_pd,
- engine_t *engine, const primitive_attr_t *attr,
- engine_t *src_engine, const memory_desc_t *src_md,
- engine_t *dst_engine, const memory_desc_t *dst_md) {
- const memory_desc_wrapper id(src_md), od(dst_md);
- bool args_ok = true
- && id.data_type() == type_i
- && od.data_type() == type_o
- && id.matches_tag(utils::pick(id.ndims() - 4,
- format_tag::oihw, format_tag::goihw))
- && od.format_kind() == format_kind::wino
- && utils::one_of(od.wino_desc().wino_format,
- mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio,
- mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio);
- if (!args_ok) return status::invalid_arguments;
-
- auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
- dst_md);
- if (_pd == nullptr) return status::out_of_memory;
- if (_pd->init() != status::success) {
- delete _pd;
- return status::unimplemented;
- }
- return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
- }
-
- status_t init() {
- status_t status = cpu_reorder_pd_t::init();
- if (status != status::success) return status;
-
- init_scratchpad();
-
- return status::success;
- }
-
- private:
- void init_scratchpad() {
- auto &o = memory_desc_wrapper(dst_md()).wino_desc();
- size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block;
- size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic;
-
- using namespace memory_tracking::names;
- auto scratchpad = scratchpad_registry().registrar();
- scratchpad.book(key_reorder_wino_transform_space,
- sizeof(in_data_t) * transform_space_size);
- scratchpad.book(key_reorder_wino_plain,
- sizeof(out_data_t) * plain_size);
- }
- };
-
-private:
- typedef typename prec_traits<type_i>::type in_data_t;
- typedef typename prec_traits<type_o>::type out_data_t;
- const int unsign_val_in_wino_domain_ = 5;
-
- wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {
- const memory_desc_wrapper src_d(pd()->src_md());
- const memory_desc_wrapper dst_d(pd()->dst_md());
-
- r_ = dst_d.wino_desc().r;
- w_alpha_ = dst_d.wino_desc().alpha;
- wino_format_ = dst_d.wino_desc().wino_format;
-
- const auto &in_dims = src_d.dims();
- int groups;
- int groups_offset;
- if (src_d.ndims() == 5) {
- groups = in_dims[0];
- groups_offset = 1;
- } else {
- groups = 1;
- groups_offset = 0;
- }
- assert(groups == 1); // groups are not supported now
- MAYBE_UNUSED(groups);
-
- or_oc_ = in_dims[0 + groups_offset];
- or_ic_ = in_dims[1 + groups_offset];
- kh_ = in_dims[2 + groups_offset];
- kw_ = in_dims[3 + groups_offset];
-
- oc_ = dst_d.wino_desc().oc;
- ic_ = dst_d.wino_desc().ic;
- oc_block_ = dst_d.wino_desc().oc_block;
- ic_block_ = dst_d.wino_desc().ic_block;
- assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0);
- nb_oc_ = oc_ / oc_block_;
- nb_ic_ = ic_ / ic_block_;
- ic2_block_ = 1;
- if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
- ic2_block_ = dst_d.wino_desc().ic2_block;
- oc2_block_ = dst_d.wino_desc().oc2_block;
- assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0);
-
- adj_scale_ = dst_d.wino_desc().adj_scale;
-
- size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_;
- size_wspace_ = r_ * w_alpha_ * oc_block_;
- }
-
- void transform(out_data_t *__restrict tmp_wei,
- const in_data_t *__restrict input,
- in_data_t *__restrict wspace) const {
- const memory_desc_wrapper src_d(pd()->src_md());
-
- const int smask = pd()->attr()->output_scales_.mask_;
- const int ndims_mask = math::ilog2q(smask + 1);
- const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask);
- const float *__restrict scales = pd()->attr()->output_scales_.scales_;
- assert(D_mask == 1 || D_mask == (size_t)oc_);
-
- /* transform weights to winograd domain */
- const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 },
- { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } };
-
- const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f },
- { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f },
- { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f },
- { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f },
- { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f },
- { 0.f, 0.f, 1.f } };
-
- float *__restrict g;
- if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi,
- mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo))
- g = (float *)G_2x2_3x3;
- else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
- g = (float *)G_4x4_3x3;
- else {
-            assert(!"Unknown winograd weights target layout");
- return;
- }
-
- int Z = oc_ * ic_;
- assert(r_ == kh_ && r_ == kw_);
-
- for (int iic = 0; iic < ic_; iic++) {
- for (int ob = 0; ob < nb_oc_; ob++) {
- const in_data_t *__restrict _inp
- = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_;
- out_data_t *__restrict _out
- = tmp_wei + (iic * nb_oc_ + ob) * oc_block_;
-
- for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; });
-
- for_nd(0, 1, r_, w_alpha_, oc_block_,
- [&](int ih, int j, int ioc) {
- for (int iw = 0; iw < r_; ++iw) {
- int inp_oc = ob * oc_block_ + ioc;
- int inp_ic = iic;
- in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_)
- ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw]
- : 0.f;
- wspace[(ih * w_alpha_ + j) * oc_block_ + ioc]
- += inp_v * g[j * r_ + iw];
- }
- });
-
- for_nd(0, 1, w_alpha_, w_alpha_, oc_block_,
- [&](int i, int j, int ioc) {
- float t = 0;
- for (int k = 0; k < r_; ++k)
- t += g[i * r_ + k]
- * wspace[(k * w_alpha_ + j) * oc_block_ + ioc];
- if (type_o == data_type::s8) {
- const float scale = (D_mask == 1)
- ? scales[0]
- : scales[ob * oc_block_ + ioc];
- _out[(i * w_alpha_ + j) * Z + ioc]
- = qz_b0<in_data_t, out_data_t>()(
- (in_data_t)t, scale * adj_scale_);
- } else {
- _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t;
- }
- });
- }}
- }
-
- void reorder_to_aaOIoi(out_data_t *__restrict output,
- const out_data_t *__restrict tmp_wei) const {
- int32_t *__restrict dst_bias = nullptr;
- if (type_o == data_type::s8) {
- const auto bias_shift = sizeof(out_data_t) * size_wino_wei_;
- const size_t bias_size = w_alpha_ * w_alpha_ * oc_;
-
- dst_bias = (int32_t *)(output + bias_shift);
- utils::array_set((int32_t *)dst_bias, 0, bias_size);
- }
- int index = 0;
- for (int u_h = 0; u_h < w_alpha_; u_h++) {
- for (int u_w = 0; u_w < w_alpha_; u_w++) {
- for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) {
- int u_h_shift = u_h * w_alpha_ * ic_ * oc_;
- int u_w_shift = u_w * ic_ * oc_;
- int u_h_shift_b = u_h * w_alpha_ * oc_;
- int u_w_shift_b = u_w * oc_;
- int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_;
- for (int ib = 0; ib < nb_ic_; ib++) {
- for (int i = 0; i < ic_block_; i++) {
- int _i = ib * ic_block_;
- int _o = ob * oc_block_;
- int ic_shift = (_i + i) * oc_;
- int oc_shift = (_o + o);
- int ic_block_shift = ib * oc_block_ * ic_block_ + i;
- int src_offset =
- u_h_shift + u_w_shift + ic_shift + oc_shift;
- int dst_offset = u_h_shift + u_w_shift + oc_block_shift
- + ic_block_shift;
-
- output[dst_offset] = tmp_wei[src_offset];
- if (type_o == data_type::s8) {
- int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift;
- if (index != unsign_val_in_wino_domain_)
- dst_bias[bias_offset]
- -= (128 * (int32_t)output[dst_offset]);
- else
- dst_bias[bias_offset] = 0;
- }
- }}
- });
- index++;
- }}
- }
-
- void reorder_to_aaOio(out_data_t *__restrict output,
- const out_data_t *__restrict tmp_wei) const {
- for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_,
- [&](int u_h, int u_w, int ob) {
- for (int ib = 0; ib < nb_ic_; ib++) {
- for (int i = 0; i < ic_block_; i++) {
- for (int o = 0; o < oc_block_; o++) {
- int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_
- + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o);
-
- int dst_offset
- = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
- + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
- + ob * nb_ic_ * ic_block_ * oc_block_
- + ib * ic_block_ * oc_block_ + i * oc_block_ + o;
- output[dst_offset] = tmp_wei[src_offset];
- }}}
- });
- }
-
- void reorder_to_aaOBiOo(out_data_t *__restrict output,
- const out_data_t *__restrict tmp_wei) const {
- int oc_chunks = nb_oc_ / oc2_block_;
-
- for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks,
- [&](int u_h, int u_w, int occ) {
- for (int ib = 0; ib < nb_ic_; ib++) {
- out_data_t *__restrict wei_ptr = output
- + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib)
- * oc2_block_ * ic_block_ * oc_block_;
- int wei_offset = 0;
- for (int i = 0; i < ic_block_; i++) {
- for (int ob2 = 0; ob2 < oc2_block_; ob2++) {
- for (int o = 0; o < oc_block_; o++) {
- int icp = ib * ic_block_ + i;
- int ocp =
- occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o;
-
- int src_offset = u_h * w_alpha_ * ic_ * oc_
- + u_w * ic_ * oc_ + icp * oc_ + ocp;
- wei_ptr[wei_offset + o] = tmp_wei[src_offset];
- }
- wei_offset += oc_block_;
- }}
- }
- });
- }
-
- void reorder_to_OBaaIBOIio(out_data_t *__restrict output,
- const out_data_t *__restrict tmp_wei) const {
- int ic_chunks = nb_ic_ / ic2_block_;
- int oc_chunks = nb_oc_ / oc2_block_;
-
- for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_,
- [&](int occ, int u_h, int u_w) {
- for (int icc = 0; icc < ic_chunks; icc++) {
- for (int ob = 0; ob < oc2_block_; ob++) {
- int ocp = (occ * oc2_block_ + ob) * oc_block_;
- for (int ib = 0; ib < ic2_block_; ib++) {
- for (int i = 0; i < ic_block_; i++) {
- int icp = (icc * ic2_block_ + ib) * ic_block_ + i;
-
- int src_offset = u_h * w_alpha_ * ic_ * oc_
- + u_w * ic_ * oc_ + icp * oc_ + ocp;
- int wei_offset
- = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w)
- * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_
- + ib) * ic_block_ + i) * oc_block_;
- for (int o = 0; o < oc_block_; o++)
- output[wei_offset + o] = tmp_wei[src_offset + o];
- }}
- }}
- });
- }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
- auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
-
- auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>(
- memory_tracking::names::key_reorder_wino_transform_space);
- auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>(
- memory_tracking::names::key_reorder_wino_plain);
-
- transform(tmp_wei, input, wspace);
-
- /* reorder to winograd domain */
- switch (wino_format_) {
- case mkldnn_wino_wei_aaOIoi:
- reorder_to_aaOIoi(output, tmp_wei); break;
- case mkldnn_wino_wei_aaOio:
- reorder_to_aaOio(output, tmp_wei); break;
- case mkldnn_wino_wei_aaOBiOo:
- reorder_to_aaOBiOo(output, tmp_wei); break;
- case mkldnn_wino_wei_OBaaIBOIio:
- reorder_to_OBaaIBOIio(output, tmp_wei); break;
-        default: assert(!"Unknown wino format"); break;
- }
-
- return status::success;
- }
-
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- int r_, w_alpha_;
- int ic_, oc_, or_ic_, or_oc_, kh_, kw_;
- int oc_block_, ic_block_, oc2_block_, ic2_block_;
- float adj_scale_;
- int nb_oc_, nb_ic_;
- mkldnn_wino_memory_format_t wino_format_;
- int size_wino_wei_;
- int size_wspace_;
-};
-
-} // namespace cpu
-} // namespace impl
-} // namespace mkldnn
-
-#endif
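For readers skimming the removed wino_reorder_t: transform() applies the Winograd weight transform U = G * g * G^T to each 3x3 filter (using G_2x2_3x3 or G_4x4_3x3 above), and the reorder_to_* helpers then scatter U into the blocked wino weight layouts. A minimal sketch of that transform for the F(2x2, 3x3) case, reusing the same 4x3 G matrix (the helper name is illustrative):

    // U = G * g * G^T for one 3x3 filter; alpha = 4 for F(2x2, 3x3).
    void wino_weight_transform_2x2_3x3(float U[4][4], const float g[3][3]) {
        static const float G[4][3] = {
            { 1.0f,  0.0f, 0.0f }, { 0.5f,  0.5f, 0.5f },
            { 0.5f, -0.5f, 0.5f }, { 0.0f,  0.0f, 1.0f } };
        float T[4][3];                                  // T = G * g
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 3; ++j) {
                T[i][j] = 0.f;
                for (int k = 0; k < 3; ++k) T[i][j] += G[i][k] * g[k][j];
            }
        for (int i = 0; i < 4; ++i)                     // U = T * G^T
            for (int j = 0; j < 4; ++j) {
                U[i][j] = 0.f;
                for (int k = 0; k < 3; ++k) U[i][j] += T[i][k] * G[j][k];
            }
    }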
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT
deleted file mode 100644
index 66b6ea55d0..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT
+++ /dev/null
@@ -1,47 +0,0 @@
-
-Copyright (c) 2007 MITSUNARI Shigeo
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-Redistributions of source code must retain the above copyright notice, this
-list of conditions and the following disclaimer.
-Redistributions in binary form must reproduce the above copyright notice,
-this list of conditions and the following disclaimer in the documentation
-and/or other materials provided with the distribution.
-Neither the name of the copyright owner nor the names of its contributors may
-be used to endorse or promote products derived from this software without
-specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-THE POSSIBILITY OF SUCH DAMAGE.
------------------------------------------------------------------------------
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-Redistributions of source code must retain the above copyright notice, this
-list of conditions and the following disclaimer.
-Redistributions in binary form must reproduce the above copyright notice,
-this list of conditions and the following disclaimer in the documentation
-and/or other materials provided with the distribution.
-The name of the copyright owner or the names of contributors may not be used
-to endorse or promote products derived from this software without specific
-prior written permission.
-This software is provided by the copyright owner and contributors "as is",
-without any warranty, express or implied, including but not limited to the
-implied warranties of merchantability and fitness for a particular purpose.
-In no event shall the copyright owner or contributors be liable for any
-direct, indirect, incidental, special, exemplary, or consequential damages
-(including, but not limited to, procurement of substitute goods or services;
-loss of use, data, or profits; or business interruption), however caused and
-whatever the theory of liability, whether in contract, strict liability, or
-tort (including negligence or otherwise), arising in any way out of the use
-of this software, even if advised of the possibility of such damage.
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h
deleted file mode 100644
index cf5771332f..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h
+++ /dev/null
@@ -1,2658 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*******************************************************************************
-* Copyright (c) 2007 MITSUNARI Shigeo
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* Redistributions of source code must retain the above copyright notice, this
-* list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* Neither the name of the copyright owner nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-*******************************************************************************/
-
-#pragma once
-#ifndef XBYAK_XBYAK_H_
-#define XBYAK_XBYAK_H_
-/*!
- @file xbyak.h
- @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++
- @author herumi
- @url https://github.com/herumi/xbyak
- @note modified new BSD license
- http://opensource.org/licenses/BSD-3-Clause
-*/
-#ifndef XBYAK_NO_OP_NAMES
- #if not +0 // trick to detect whether 'not' is operator or not
- #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()."
- #endif
-#endif
-
-#include <stdio.h> // for debug print
-#include <assert.h>
-#include <list>
-#include <string>
-#include <algorithm>
-#ifndef NDEBUG
-#include <iostream>
-#endif
-
-// #define XBYAK_DISABLE_AVX512
-
-//#define XBYAK_USE_MMAP_ALLOCATOR
-#if !defined(__GNUC__) || defined(__MINGW32__)
- #undef XBYAK_USE_MMAP_ALLOCATOR
-#endif
-
-#ifdef __GNUC__
- #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor))
-#else
- #define XBYAK_GNUC_PREREQ(major, minor) 0
-#endif
-
-// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft.
-#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\
- ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__)))
- #include <unordered_set>
- #define XBYAK_STD_UNORDERED_SET std::unordered_set
- #include <unordered_map>
- #define XBYAK_STD_UNORDERED_MAP std::unordered_map
- #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap
-
-/*
- Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using
- libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version).
-*/
-#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__)
- #include <tr1/unordered_set>
- #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set
- #include <tr1/unordered_map>
- #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map
- #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap
-
-#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600)
- #include <unordered_set>
- #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set
- #include <unordered_map>
- #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map
- #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap
-
-#else
- #include <set>
- #define XBYAK_STD_UNORDERED_SET std::set
- #include <map>
- #define XBYAK_STD_UNORDERED_MAP std::map
- #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap
-#endif
-#ifdef _WIN32
- #include <winsock2.h>
- #include <windows.h>
- #include <malloc.h>
-#elif defined(__GNUC__)
- #include <unistd.h>
- #include <sys/mman.h>
- #include <stdlib.h>
-#endif
-#if !defined(_MSC_VER) || (_MSC_VER >= 1600)
- #include <stdint.h>
-#endif
-
-#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__))
- #define XBYAK64_WIN
-#elif defined(__x86_64__)
- #define XBYAK64_GCC
-#endif
-#if !defined(XBYAK64) && !defined(XBYAK32)
- #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN)
- #define XBYAK64
- #else
- #define XBYAK32
- #endif
-#endif
-
-#if (__cplusplus >= 201103) || (_MSC_VER >= 1800)
- #define XBYAK_VARIADIC_TEMPLATE
-#endif
-
-#ifdef _MSC_VER
- #pragma warning(push)
- #pragma warning(disable : 4514) /* remove inline function */
- #pragma warning(disable : 4786) /* identifier is too long */
- #pragma warning(disable : 4503) /* name is too long */
-	#pragma warning(disable : 4127) /* constant expression */
-#endif
-
-namespace Xbyak {
-
-enum {
- DEFAULT_MAX_CODE_SIZE = 4096,
- VERSION = 0x5760 /* 0xABCD = A.BC(D) */
-};
-
-#ifndef MIE_INTEGER_TYPE_DEFINED
-#define MIE_INTEGER_TYPE_DEFINED
-#ifdef _MSC_VER
- typedef unsigned __int64 uint64;
- typedef __int64 sint64;
-#else
- typedef uint64_t uint64;
- typedef int64_t sint64;
-#endif
-typedef unsigned int uint32;
-typedef unsigned short uint16;
-typedef unsigned char uint8;
-#endif
-
-#ifndef MIE_ALIGN
- #ifdef _MSC_VER
- #define MIE_ALIGN(x) __declspec(align(x))
- #else
- #define MIE_ALIGN(x) __attribute__((aligned(x)))
- #endif
-#endif
-#ifndef MIE_PACK // for shufps
- #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w))
-#endif
-
-enum {
- ERR_NONE = 0,
- ERR_BAD_ADDRESSING,
- ERR_CODE_IS_TOO_BIG,
- ERR_BAD_SCALE,
- ERR_ESP_CANT_BE_INDEX,
- ERR_BAD_COMBINATION,
- ERR_BAD_SIZE_OF_REGISTER,
- ERR_IMM_IS_TOO_BIG,
- ERR_BAD_ALIGN,
- ERR_LABEL_IS_REDEFINED,
- ERR_LABEL_IS_TOO_FAR,
- ERR_LABEL_IS_NOT_FOUND,
- ERR_CODE_ISNOT_COPYABLE,
- ERR_BAD_PARAMETER,
- ERR_CANT_PROTECT,
- ERR_CANT_USE_64BIT_DISP,
- ERR_OFFSET_IS_TOO_BIG,
- ERR_MEM_SIZE_IS_NOT_SPECIFIED,
- ERR_BAD_MEM_SIZE,
- ERR_BAD_ST_COMBINATION,
- ERR_OVER_LOCAL_LABEL, // not used
- ERR_UNDER_LOCAL_LABEL,
- ERR_CANT_ALLOC,
- ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW,
- ERR_BAD_PROTECT_MODE,
- ERR_BAD_PNUM,
- ERR_BAD_TNUM,
- ERR_BAD_VSIB_ADDRESSING,
- ERR_CANT_CONVERT,
- ERR_LABEL_ISNOT_SET_BY_L,
- ERR_LABEL_IS_ALREADY_SET_BY_L,
- ERR_BAD_LABEL_STR,
- ERR_MUNMAP,
- ERR_OPMASK_IS_ALREADY_SET,
- ERR_ROUNDING_IS_ALREADY_SET,
- ERR_K0_IS_INVALID,
- ERR_EVEX_IS_INVALID,
- ERR_SAE_IS_INVALID,
- ERR_ER_IS_INVALID,
- ERR_INVALID_BROADCAST,
- ERR_INVALID_OPMASK_WITH_MEMORY,
- ERR_INVALID_ZERO,
- ERR_INVALID_RIP_IN_AUTO_GROW,
- ERR_INVALID_MIB_ADDRESS,
- ERR_INTERNAL,
- ERR_X2APIC_IS_NOT_SUPPORTED
-};
-
-class Error : public std::exception {
- int err_;
-public:
- explicit Error(int err) : err_(err)
- {
- if (err_ < 0 || err_ > ERR_INTERNAL) {
- fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_);
- //exit(1);
- }
- }
- operator int() const { return err_; }
- const char *what() const throw()
- {
- static const char *errTbl[] = {
- "none",
- "bad addressing",
- "code is too big",
- "bad scale",
- "esp can't be index",
- "bad combination",
- "bad size of register",
- "imm is too big",
- "bad align",
- "label is redefined",
- "label is too far",
- "label is not found",
- "code is not copyable",
- "bad parameter",
- "can't protect",
- "can't use 64bit disp(use (void*))",
- "offset is too big",
- "MEM size is not specified",
- "bad mem size",
- "bad st combination",
- "over local label",
- "under local label",
- "can't alloc",
- "T_SHORT is not supported in AutoGrow",
- "bad protect mode",
- "bad pNum",
- "bad tNum",
- "bad vsib addressing",
- "can't convert",
- "label is not set by L()",
- "label is already set by L()",
- "bad label string",
- "err munmap",
- "opmask is already set",
- "rounding is already set",
- "k0 is invalid",
- "evex is invalid",
- "sae(suppress all exceptions) is invalid",
- "er(embedded rounding) is invalid",
- "invalid broadcast",
- "invalid opmask with memory",
- "invalid zero",
- "invalid rip in AutoGrow",
- "invalid mib address",
- "internal error",
- "x2APIC is not supported"
- };
- assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl));
- return errTbl[err_];
- }
-};
-
-inline const char *ConvertErrorToString(const Error& err)
-{
- return err.what();
-}
-
-inline void *AlignedMalloc(size_t size, size_t alignment)
-{
-#ifdef __MINGW32__
- return __mingw_aligned_malloc(size, alignment);
-#elif defined(_WIN32)
- return _aligned_malloc(size, alignment);
-#else
- void *p;
- int ret = posix_memalign(&p, alignment, size);
- return (ret == 0) ? p : 0;
-#endif
-}
-
-inline void AlignedFree(void *p)
-{
-#ifdef __MINGW32__
- __mingw_aligned_free(p);
-#elif defined(_MSC_VER)
- _aligned_free(p);
-#else
- free(p);
-#endif
-}
-
-template<class To, class From>
-inline const To CastTo(From p) throw()
-{
- return (const To)(size_t)(p);
-}
-namespace inner {
-
-static const size_t ALIGN_PAGE_SIZE = 4096;
-
-inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; }
-inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; }
-
-inline uint32 VerifyInInt32(uint64 x)
-{
-#ifdef XBYAK64
- if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG);
-#endif
- return static_cast<uint32>(x);
-}
-
-enum LabelMode {
- LasIs, // as is
- Labs, // absolute
- LaddTop // (addr + top) for mov(reg, label) with AutoGrow
-};
-
-} // inner
-
-/*
- custom allocator
-*/
-struct Allocator {
- virtual uint8 *alloc(size_t size) { return reinterpret_cast<uint8*>(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); }
- virtual void free(uint8 *p) { AlignedFree(p); }
- virtual ~Allocator() {}
- /* override to return false if you call protect() manually */
- virtual bool useProtect() const { return true; }
-};
-
-#ifdef XBYAK_USE_MMAP_ALLOCATOR
-class MmapAllocator : Allocator {
- typedef XBYAK_STD_UNORDERED_MAP<uintptr_t, size_t> SizeList;
- SizeList sizeList_;
-public:
- uint8 *alloc(size_t size)
- {
- const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1;
- size = (size + alignedSizeM1) & ~alignedSizeM1;
-#ifdef MAP_ANONYMOUS
- const int mode = MAP_PRIVATE | MAP_ANONYMOUS;
-#elif defined(MAP_ANON)
- const int mode = MAP_PRIVATE | MAP_ANON;
-#else
- #error "not supported"
-#endif
- void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0);
- if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC);
- assert(p);
- sizeList_[(uintptr_t)p] = size;
- return (uint8*)p;
- }
- void free(uint8 *p)
- {
- if (p == 0) return;
- SizeList::iterator i = sizeList_.find((uintptr_t)p);
- if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER);
- if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP);
- sizeList_.erase(i);
- }
-};
-#endif
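// Illustrative sketch, not from the removed header: both allocators return
// page-aligned blocks so that CodeArray::protect() further down can flip the
// whole range to executable. A minimal equivalent of that flow with the same
// POSIX calls MmapAllocator::alloc() and protect() rely on:
//
//   void *buf = mmap(NULL, 4096, PROT_READ | PROT_WRITE,
//                    MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
//   /* ...emit machine code into buf... */
//   mprotect(buf, 4096, PROT_READ | PROT_EXEC); // drop write, add exec (W^X)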
-
-class Address;
-class Reg;
-
-class Operand {
- static const uint8 EXT8BIT = 0x20;
- unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil
- unsigned int kind_:9;
- unsigned int bit_:10;
-protected:
- unsigned int zero_:1;
- unsigned int mask_:3;
- unsigned int rounding_:3;
- void setIdx(int idx) { idx_ = idx; }
-public:
- enum Kind {
- NONE = 0,
- MEM = 1 << 0,
- REG = 1 << 1,
- MMX = 1 << 2,
- FPU = 1 << 3,
- XMM = 1 << 4,
- YMM = 1 << 5,
- ZMM = 1 << 6,
- OPMASK = 1 << 7,
- BNDREG = 1 << 8
- };
- enum Code {
-#ifdef XBYAK64
- RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15,
- R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D,
- R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W,
- R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B,
- SPL = 4, BPL, SIL, DIL,
-#endif
- EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI,
- AX = 0, CX, DX, BX, SP, BP, SI, DI,
- AL = 0, CL, DL, BL, AH, CH, DH, BH
- };
- Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { }
- Operand(int idx, Kind kind, int bit, bool ext8bit = 0)
- : idx_(static_cast<uint8>(idx | (ext8bit ? EXT8BIT : 0)))
- , kind_(kind)
- , bit_(bit)
- , zero_(0), mask_(0), rounding_(0)
- {
- assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two
- }
- Kind getKind() const { return static_cast<Kind>(kind_); }
- int getIdx() const { return idx_ & (EXT8BIT - 1); }
- bool isNone() const { return kind_ == 0; }
- bool isMMX() const { return is(MMX); }
- bool isXMM() const { return is(XMM); }
- bool isYMM() const { return is(YMM); }
- bool isZMM() const { return is(ZMM); }
- bool isXMEM() const { return is(XMM | MEM); }
- bool isYMEM() const { return is(YMM | MEM); }
- bool isZMEM() const { return is(ZMM | MEM); }
- bool isOPMASK() const { return is(OPMASK); }
- bool isBNDREG() const { return is(BNDREG); }
- bool isREG(int bit = 0) const { return is(REG, bit); }
- bool isMEM(int bit = 0) const { return is(MEM, bit); }
- bool isFPU() const { return is(FPU); }
- bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; }
- bool isExtIdx() const { return (getIdx() & 8) != 0; }
- bool isExtIdx2() const { return (getIdx() & 16) != 0; }
- bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); }
- bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); }
- bool hasZero() const { return zero_; }
- int getOpmaskIdx() const { return mask_; }
- int getRounding() const { return rounding_; }
- void setKind(Kind kind)
- {
- if ((kind & (XMM|YMM|ZMM)) == 0) return;
- kind_ = kind;
- bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512;
- }
- void setBit(int bit) { bit_ = bit; }
- void setOpmaskIdx(int idx, bool ignore_idx0 = false)
- {
- if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID);
- if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET);
- mask_ = idx;
- }
- void setRounding(int idx)
- {
- if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET);
- rounding_ = idx;
- }
- void setZero() { zero_ = true; }
- // ah, ch, dh, bh?
- bool isHigh8bit() const
- {
- if (!isBit(8)) return false;
- if (isExt8bit()) return false;
- const int idx = getIdx();
- return AH <= idx && idx <= BH;
- }
-	// any bit is acceptable if bit == 0
- bool is(int kind, uint32 bit = 0) const
- {
- return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16)
- }
- bool isBit(uint32 bit) const { return (bit_ & bit) != 0; }
- uint32 getBit() const { return bit_; }
- const char *toString() const
- {
- const int idx = getIdx();
- if (kind_ == REG) {
- if (isExt8bit()) {
- static const char *tbl[4] = { "spl", "bpl", "sil", "dil" };
- return tbl[idx - 4];
- }
- static const char *tbl[4][16] = {
- { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" },
- { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" },
- { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" },
- { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" },
- };
- return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 2 : 3][idx];
- } else if (isOPMASK()) {
- static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" };
- return tbl[idx];
- } else if (isZMM()) {
- static const char *tbl[32] = {
- "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15",
- "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31"
- };
- return tbl[idx];
- } else if (isYMM()) {
- static const char *tbl[32] = {
- "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15",
- "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31"
- };
- return tbl[idx];
- } else if (isXMM()) {
- static const char *tbl[32] = {
- "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15",
- "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31"
- };
- return tbl[idx];
- } else if (isMMX()) {
- static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" };
- return tbl[idx];
- } else if (isFPU()) {
- static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" };
- return tbl[idx];
- } else if (isBNDREG()) {
- static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" };
- return tbl[idx];
- }
- throw Error(ERR_INTERNAL);
- }
- bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; }
- bool operator==(const Operand& rhs) const;
- bool operator!=(const Operand& rhs) const { return !operator==(rhs); }
- const Address& getAddress() const;
- const Reg& getReg() const;
-};
-
-class Label;
-
-struct Reg8;
-struct Reg16;
-struct Reg32;
-#ifdef XBYAK64
-struct Reg64;
-#endif
-class Reg : public Operand {
-public:
- Reg() { }
- Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { }
- Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); }
- uint8 getRexW() const { return isREG(64) ? 8 : 0; }
- uint8 getRexR() const { return isExtIdx() ? 4 : 0; }
- uint8 getRexX() const { return isExtIdx() ? 2 : 0; }
- uint8 getRexB() const { return isExtIdx() ? 1 : 0; }
- uint8 getRex(const Reg& base = Reg()) const
- {
- uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB();
- if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40;
- return rex;
- }
- Reg8 cvt8() const;
- Reg16 cvt16() const;
- Reg32 cvt32() const;
-#ifdef XBYAK64
- Reg64 cvt64() const;
-#endif
-};
-
-inline const Reg& Operand::getReg() const
-{
- assert(!isMEM());
- return static_cast<const Reg&>(*this);
-}
-
-struct Reg8 : public Reg {
- explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { }
-};
-
-struct Reg16 : public Reg {
- explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { }
-};
-
-struct Mmx : public Reg {
- explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { }
-};
-
-struct EvexModifierRounding {
- enum {
- T_RN_SAE = 1,
- T_RD_SAE = 2,
- T_RU_SAE = 3,
- T_RZ_SAE = 4,
- T_SAE = 5
- };
- explicit EvexModifierRounding(int rounding) : rounding(rounding) {}
- int rounding;
-};
-struct EvexModifierZero{EvexModifierZero() {}};
-
-struct Xmm : public Mmx {
- explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { }
- Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { }
- Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; }
- Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; }
- Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; }
-};
-
-struct Ymm : public Xmm {
- explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { }
- Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; }
-};
-
-struct Zmm : public Ymm {
- explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { }
- Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; }
-};
-
-struct Opmask : public Reg {
- explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {}
-};
-
-struct BoundsReg : public Reg {
- explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {}
-};
-
-template<class T>T operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; }
-template<class T>T operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; }
-template<class T>T operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; }
-
-struct Fpu : public Reg {
- explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { }
-};
-
-struct Reg32e : public Reg {
- explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {}
-};
-struct Reg32 : public Reg32e {
- explicit Reg32(int idx = 0) : Reg32e(idx, 32) {}
-};
-#ifdef XBYAK64
-struct Reg64 : public Reg32e {
- explicit Reg64(int idx = 0) : Reg32e(idx, 64) {}
-};
-struct RegRip {
- sint64 disp_;
- const Label* label_;
- bool isAddr_;
- explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {}
- friend const RegRip operator+(const RegRip& r, int disp) {
- return RegRip(r.disp_ + disp, r.label_, r.isAddr_);
- }
- friend const RegRip operator-(const RegRip& r, int disp) {
- return RegRip(r.disp_ - disp, r.label_, r.isAddr_);
- }
- friend const RegRip operator+(const RegRip& r, sint64 disp) {
- return RegRip(r.disp_ + disp, r.label_, r.isAddr_);
- }
- friend const RegRip operator-(const RegRip& r, sint64 disp) {
- return RegRip(r.disp_ - disp, r.label_, r.isAddr_);
- }
- friend const RegRip operator+(const RegRip& r, const Label& label) {
- if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING);
- return RegRip(r.disp_, &label);
- }
- friend const RegRip operator+(const RegRip& r, const void *addr) {
- if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING);
- return RegRip(r.disp_ + (sint64)addr, 0, true);
- }
-};
-#endif
-
-inline Reg8 Reg::cvt8() const
-{
- const int idx = getIdx();
- if (isBit(8)) return Reg8(idx, isExt8bit());
-#ifdef XBYAK32
- if (idx >= 4) throw Error(ERR_CANT_CONVERT);
-#endif
- return Reg8(idx, 4 <= idx && idx < 8);
-}
-
-inline Reg16 Reg::cvt16() const
-{
- const int idx = getIdx();
- if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
- return Reg16(idx);
-}
-
-inline Reg32 Reg::cvt32() const
-{
- const int idx = getIdx();
- if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
- return Reg32(idx);
-}
-
-#ifdef XBYAK64
-inline Reg64 Reg::cvt64() const
-{
- const int idx = getIdx();
- if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
- return Reg64(idx);
-}
-#endif
-
-#ifndef XBYAK_DISABLE_SEGMENT
-// not derived from Reg
-class Segment {
- int idx_;
-public:
- enum {
- es, cs, ss, ds, fs, gs
- };
- explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); }
- int getIdx() const { return idx_; }
- const char *toString() const
- {
- static const char tbl[][3] = {
- "es", "cs", "ss", "ds", "fs", "gs"
- };
- return tbl[idx_];
- }
-};
-#endif
-
-class RegExp {
-public:
-#ifdef XBYAK64
- enum { i32e = 32 | 64 };
-#else
- enum { i32e = 32 };
-#endif
- RegExp(size_t disp = 0) : scale_(0), disp_(disp) { }
- RegExp(const Reg& r, int scale = 1)
- : scale_(scale)
- , disp_(0)
- {
- if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER);
- if (scale == 0) return;
- if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE);
- if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index
- index_ = r;
- } else {
- base_ = r;
- }
- }
- bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); }
- RegExp optimize() const
- {
- RegExp exp = *this;
- // [reg * 2] => [reg + reg]
- if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) {
- exp.base_ = index_;
- exp.scale_ = 1;
- }
- return exp;
- }
- bool operator==(const RegExp& rhs) const
- {
- return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_;
- }
- const Reg& getBase() const { return base_; }
- const Reg& getIndex() const { return index_; }
- int getScale() const { return scale_; }
- size_t getDisp() const { return disp_; }
- void verify() const
- {
- if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER);
- if (index_.getBit() && index_.getBit() <= 64) {
- if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX);
- if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER);
- }
- }
- friend RegExp operator+(const RegExp& a, const RegExp& b);
- friend RegExp operator-(const RegExp& e, size_t disp);
- uint8 getRex() const
- {
- uint8 rex = index_.getRexX() | base_.getRexB();
- return rex ? uint8(rex | 0x40) : 0;
- }
-private:
- /*
- [base_ + index_ * scale_ + disp_]
- base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm
- */
- Reg base_;
- Reg index_;
- int scale_;
- size_t disp_;
-};
-
-inline RegExp operator+(const RegExp& a, const RegExp& b)
-{
- if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING);
- RegExp ret = a;
- if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; }
- if (b.base_.getBit()) {
- if (ret.base_.getBit()) {
- if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING);
- // base + base => base + index * 1
- ret.index_ = b.base_;
- // [reg + esp] => [esp + reg]
- if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_);
- ret.scale_ = 1;
- } else {
- ret.base_ = b.base_;
- }
- }
- ret.disp_ += b.disp_;
- return ret;
-}
-inline RegExp operator*(const Reg& r, int scale)
-{
- return RegExp(r, scale);
-}
-inline RegExp operator-(const RegExp& e, size_t disp)
-{
- RegExp ret = e;
- ret.disp_ -= disp;
- return ret;
-}
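// Illustrative sketch, not from the removed header: the RegExp operators
// above let x86 addressing expressions compose algebraically into
// [base + index*scale + disp]. Assuming the register types declared earlier
// in this header:
//
//   Xbyak::Reg64 rax(Xbyak::Operand::RAX), rbx(Xbyak::Operand::RBX);
//   Xbyak::RegExp e = rax + rbx * 4 + 16; // base=rax, index=rbx, scale=4, disp=16
//   e = e - 8;                            // disp becomes 8
//   // AddressFrame (declared below) then wraps such an expression via ptr[e].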
-
-// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc)
-void *const AutoGrow = (void*)1; //-V566
-void *const DontSetProtectRWE = (void*)2; //-V566
-
-class CodeArray {
- enum Type {
- USER_BUF = 1, // use userPtr(non alignment, non protect)
- ALLOC_BUF, // use new(alignment, protect)
- AUTO_GROW // automatically move and grow memory if necessary
- };
- CodeArray(const CodeArray& rhs);
- void operator=(const CodeArray&);
- bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; }
- struct AddrInfo {
- size_t codeOffset; // position to write
- size_t jmpAddr; // value to write
- int jmpSize; // size of jmpAddr
- inner::LabelMode mode;
- AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode)
- : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {}
- uint64 getVal(const uint8 *top) const
- {
- uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top);
- if (jmpSize == 4) disp = inner::VerifyInInt32(disp);
- return disp;
- }
- };
- typedef std::list<AddrInfo> AddrInfoList;
- AddrInfoList addrInfoList_;
- const Type type_;
-#ifdef XBYAK_USE_MMAP_ALLOCATOR
- MmapAllocator defaultAllocator_;
-#else
- Allocator defaultAllocator_;
-#endif
- Allocator *alloc_;
-protected:
- size_t maxSize_;
- uint8 *top_;
- size_t size_;
- bool isCalledCalcJmpAddress_;
-
- bool useProtect() const { return alloc_->useProtect(); }
- /*
- allocate new memory and copy old data to the new area
- */
- void growMemory()
- {
- const size_t newSize = (std::max<size_t>)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2);
- uint8 *newTop = alloc_->alloc(newSize);
- if (newTop == 0) throw Error(ERR_CANT_ALLOC);
- for (size_t i = 0; i < size_; i++) newTop[i] = top_[i];
- alloc_->free(top_);
- top_ = newTop;
- maxSize_ = newSize;
- }
- /*
- calc jmp address for AutoGrow mode
- */
- void calcJmpAddress()
- {
- if (isCalledCalcJmpAddress_) return;
- for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) {
- uint64 disp = i->getVal(top_);
- rewrite(i->codeOffset, disp, i->jmpSize);
- }
- isCalledCalcJmpAddress_ = true;
- }
-public:
- enum ProtectMode {
- PROTECT_RW = 0, // read/write
- PROTECT_RWE = 1, // read/write/exec
- PROTECT_RE = 2 // read/exec
- };
- explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0)
- : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF)
- , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_)
- , maxSize_(maxSize)
- , top_(type_ == USER_BUF ? reinterpret_cast<uint8*>(userPtr) : alloc_->alloc((std::max<size_t>)(maxSize, 1)))
- , size_(0)
- , isCalledCalcJmpAddress_(false)
- {
- if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC);
- if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) {
- alloc_->free(top_);
- throw Error(ERR_CANT_PROTECT);
- }
- }
- virtual ~CodeArray()
- {
- if (isAllocType()) {
- if (useProtect()) setProtectModeRW(false);
- alloc_->free(top_);
- }
- }
- bool setProtectMode(ProtectMode mode, bool throwException = true)
- {
- bool isOK = protect(top_, maxSize_, mode);
- if (isOK) return true;
- if (throwException) throw Error(ERR_CANT_PROTECT);
- return false;
- }
- bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); }
- bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); }
- void resetSize()
- {
- size_ = 0;
- addrInfoList_.clear();
- isCalledCalcJmpAddress_ = false;
- }
- void db(int code)
- {
- if (size_ >= maxSize_) {
- if (type_ == AUTO_GROW) {
- growMemory();
- } else {
- throw Error(ERR_CODE_IS_TOO_BIG);
- }
- }
- top_[size_++] = static_cast<uint8>(code);
- }
- void db(const uint8 *code, size_t codeSize)
- {
- for (size_t i = 0; i < codeSize; i++) db(code[i]);
- }
- void db(uint64 code, size_t codeSize)
- {
- if (codeSize > 8) throw Error(ERR_BAD_PARAMETER);
- for (size_t i = 0; i < codeSize; i++) db(static_cast<uint8>(code >> (i * 8)));
- }
- void dw(uint32 code) { db(code, 2); }
- void dd(uint32 code) { db(code, 4); }
- void dq(uint64 code) { db(code, 8); }
- const uint8 *getCode() const { return top_; }
- template<class F>
- const F getCode() const { return reinterpret_cast<F>(top_); }
- const uint8 *getCurr() const { return &top_[size_]; }
- template<class F>
- const F getCurr() const { return reinterpret_cast<F>(&top_[size_]); }
- size_t getSize() const { return size_; }
- void setSize(size_t size)
- {
- if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG);
- size_ = size;
- }
- void dump() const
- {
- const uint8 *p = getCode();
- size_t bufSize = getSize();
- size_t remain = bufSize;
- for (int i = 0; i < 4; i++) {
- size_t disp = 16;
- if (remain < 16) {
- disp = remain;
- }
- for (size_t j = 0; j < 16; j++) {
- if (j < disp) {
- printf("%02X", p[i * 16 + j]);
- }
- }
- putchar('\n');
- remain -= disp;
- if (remain == 0) {
- break;
- }
- }
- }
- /*
- @param offset [in] offset from top
- @param disp [in] offset from the next of jmp
- @param size [in] write size(1, 2, 4, 8)
- */
- void rewrite(size_t offset, uint64 disp, size_t size)
- {
- assert(offset < maxSize_);
- if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER);
- uint8 *const data = top_ + offset;
- for (size_t i = 0; i < size; i++) {
- data[i] = static_cast<uint8>(disp >> (i * 8));
- }
- }
- void save(size_t offset, size_t val, int size, inner::LabelMode mode)
- {
- addrInfoList_.push_back(AddrInfo(offset, val, size, mode));
- }
- bool isAutoGrow() const { return type_ == AUTO_GROW; }
- bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; }
- /**
- change exec permission of memory
- @param addr [in] buffer address
- @param size [in] buffer size
- @param protectMode [in] mode(RW/RWE/RE)
- @return true(success), false(failure)
- */
- static inline bool protect(const void *addr, size_t size, int protectMode)
- {
-#if defined(_WIN32)
- const DWORD c_rw = PAGE_READWRITE;
- const DWORD c_rwe = PAGE_EXECUTE_READWRITE;
- const DWORD c_re = PAGE_EXECUTE_READ;
- DWORD mode;
-#else
- const int c_rw = PROT_READ | PROT_WRITE;
- const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC;
- const int c_re = PROT_READ | PROT_EXEC;
- int mode;
-#endif
- switch (protectMode) {
- case PROTECT_RW: mode = c_rw; break;
- case PROTECT_RWE: mode = c_rwe; break;
- case PROTECT_RE: mode = c_re; break;
- default:
- return false;
- }
-#if defined(_WIN32)
- DWORD oldProtect;
- return VirtualProtect(const_cast<void*>(addr), size, mode, &oldProtect) != 0;
-#elif defined(__GNUC__)
- size_t pageSize = sysconf(_SC_PAGESIZE);
- size_t iaddr = reinterpret_cast<size_t>(addr);
- size_t roundAddr = iaddr & ~(pageSize - static_cast<size_t>(1));
-#ifndef NDEBUG
- if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize);
-#endif
- return mprotect(reinterpret_cast<void*>(roundAddr), size + (iaddr - roundAddr), mode) == 0;
-#else
- return true;
-#endif
- }
- /**
- get aligned memory pointer
- @param addr [in] address
- @param alignedSize [in] power of two
-	@return addr aligned up to alignedSize
- */
- static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16)
- {
- return reinterpret_cast<uint8*>((reinterpret_cast<size_t>(addr) + alignedSize - 1) & ~(alignedSize - static_cast<size_t>(1)));
- }
-};
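// Illustrative sketch, not from the removed header: db()/rewrite() above
// store multi-byte values little-endian, and getAlignedAddress() rounds up
// with the usual power-of-two mask. Standalone equivalents of those two steps:
//
//   inline void emit_le(unsigned char *dst, unsigned long long v, int n) {
//       for (int i = 0; i < n; ++i) dst[i] = (unsigned char)(v >> (i * 8));
//   }
//   inline size_t align_up(size_t addr, size_t a) { // a: power of two
//       return (addr + a - 1) & ~(a - 1);           // align_up(0x1009, 16) == 0x1010
//   }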
-
-class Address : public Operand {
-public:
- enum Mode {
- M_ModRM,
- M_64bitDisp,
- M_rip,
- M_ripAddr
- };
- Address(uint32 sizeBit, bool broadcast, const RegExp& e)
- : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast)
- {
- e_.verify();
- }
-#ifdef XBYAK64
- explicit Address(size_t disp)
- : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ }
- Address(uint32 sizeBit, bool broadcast, const RegRip& addr)
- : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { }
-#endif
- RegExp getRegExp(bool optimize = true) const
- {
- return optimize ? e_.optimize() : e_;
- }
- Mode getMode() const { return mode_; }
- bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; }
- bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax
- size_t getDisp() const { return e_.getDisp(); }
- uint8 getRex() const
- {
- if (mode_ != M_ModRM) return 0;
- return getRegExp().getRex();
- }
- bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset
- bool isBroadcast() const { return broadcast_; }
- const Label* getLabel() const { return label_; }
- bool operator==(const Address& rhs) const
- {
- return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_;
- }
- bool operator!=(const Address& rhs) const { return !operator==(rhs); }
- bool isVsib() const { return e_.isVsib(); }
-private:
- RegExp e_;
- const Label* label_;
- Mode mode_;
- bool broadcast_;
-};
-
-inline const Address& Operand::getAddress() const
-{
- assert(isMEM());
- return static_cast<const Address&>(*this);
-}
-
-inline bool Operand::operator==(const Operand& rhs) const
-{
- if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress();
- return isEqualIfNotInherited(rhs);
-}
-
-class AddressFrame {
- void operator=(const AddressFrame&);
- AddressFrame(const AddressFrame&);
-public:
- const uint32 bit_;
- const bool broadcast_;
- explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { }
- Address operator[](const RegExp& e) const
- {
- return Address(bit_, broadcast_, e);
- }
- Address operator[](const void *disp) const
- {
- return Address(bit_, broadcast_, RegExp(reinterpret_cast<size_t>(disp)));
- }
-#ifdef XBYAK64
- Address operator[](uint64 disp) const { return Address(disp); }
- Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); }
-#endif
-};
-
-struct JmpLabel {
- size_t endOfJmp; /* offset from top to the end address of jmp */
- int jmpSize;
- inner::LabelMode mode;
- size_t disp; // disp for [rip + disp]
- explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0)
- : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp)
- {
- }
-};
-
-class LabelManager;
-
-class Label {
- mutable LabelManager *mgr;
- mutable int id;
- friend class LabelManager;
-public:
- Label() : mgr(0), id(0) {}
- Label(const Label& rhs);
- Label& operator=(const Label& rhs);
- ~Label();
- void clear() { mgr = 0; id = 0; }
- int getId() const { return id; }
- const uint8 *getAddress() const;
-
- // backward compatibility
- static inline std::string toStr(int num)
- {
- char buf[16];
-#if defined(_MSC_VER) && (_MSC_VER < 1900)
- _snprintf_s
-#else
- snprintf
-#endif
- (buf, sizeof(buf), ".%08x", num);
- return buf;
- }
-};
-
-class LabelManager {
- // for string label
- struct SlabelVal {
- size_t offset;
- SlabelVal(size_t offset) : offset(offset) {}
- };
- typedef XBYAK_STD_UNORDERED_MAP<std::string, SlabelVal> SlabelDefList;
- typedef XBYAK_STD_UNORDERED_MULTIMAP<std::string, const JmpLabel> SlabelUndefList;
- struct SlabelState {
- SlabelDefList defList;
- SlabelUndefList undefList;
- };
- typedef std::list<SlabelState> StateList;
- // for Label class
- struct ClabelVal {
- ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {}
- size_t offset;
- int refCount;
- };
- typedef XBYAK_STD_UNORDERED_MAP<int, ClabelVal> ClabelDefList;
- typedef XBYAK_STD_UNORDERED_MULTIMAP<int, const JmpLabel> ClabelUndefList;
- typedef XBYAK_STD_UNORDERED_SET<Label*> LabelPtrList;
-
- CodeArray *base_;
- // global : stateList_.front(), local : stateList_.back()
- StateList stateList_;
- mutable int labelId_;
- ClabelDefList clabelDefList_;
- ClabelUndefList clabelUndefList_;
- LabelPtrList labelPtrList_;
-
- int getId(const Label& label) const
- {
- if (label.id == 0) label.id = labelId_++;
- return label.id;
- }
- template<class DefList, class UndefList, class T>
- void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset)
- {
- // add label
- typename DefList::value_type item(labelId, addrOffset);
- std::pair<typename DefList::iterator, bool> ret = defList.insert(item);
- if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED);
- // search undefined label
- for (;;) {
- typename UndefList::iterator itr = undefList.find(labelId);
- if (itr == undefList.end()) break;
- const JmpLabel *jmp = &itr->second;
- const size_t offset = jmp->endOfJmp - jmp->jmpSize;
- size_t disp;
- if (jmp->mode == inner::LaddTop) {
- disp = addrOffset;
- } else if (jmp->mode == inner::Labs) {
- disp = size_t(base_->getCurr());
- } else {
- disp = addrOffset - jmp->endOfJmp + jmp->disp;
-#ifdef XBYAK64
- if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG);
-#endif
- if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR);
- }
- if (base_->isAutoGrow()) {
- base_->save(offset, disp, jmp->jmpSize, jmp->mode);
- } else {
- base_->rewrite(offset, disp, jmp->jmpSize);
- }
- undefList.erase(itr);
- }
- }
- template<class DefList, class T>
- bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const
- {
- typename DefList::const_iterator i = defList.find(label);
- if (i == defList.end()) return false;
- *offset = i->second.offset;
- return true;
- }
- friend class Label;
- void incRefCount(int id, Label *label)
- {
- clabelDefList_[id].refCount++;
- labelPtrList_.insert(label);
- }
- void decRefCount(int id, Label *label)
- {
- labelPtrList_.erase(label);
- ClabelDefList::iterator i = clabelDefList_.find(id);
- if (i == clabelDefList_.end()) return;
- if (i->second.refCount == 1) {
- clabelDefList_.erase(id);
- } else {
- --i->second.refCount;
- }
- }
- template<class T>
- bool hasUndefinedLabel_inner(const T& list) const
- {
-#ifndef NDEBUG
- for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) {
- std::cerr << "undefined label:" << i->first << std::endl;
- }
-#endif
- return !list.empty();
- }
- // detach all labels linked to LabelManager
- void resetLabelPtrList()
- {
- for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) {
- (*i)->clear();
- }
- labelPtrList_.clear();
- }
-public:
- LabelManager()
- {
- reset();
- }
- ~LabelManager()
- {
- resetLabelPtrList();
- }
- void reset()
- {
- base_ = 0;
- labelId_ = 1;
- stateList_.clear();
- stateList_.push_back(SlabelState());
- stateList_.push_back(SlabelState());
- clabelDefList_.clear();
- clabelUndefList_.clear();
- resetLabelPtrList();
- }
- void enterLocal()
- {
- stateList_.push_back(SlabelState());
- }
- void leaveLocal()
- {
- if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL);
- if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND);
- stateList_.pop_back();
- }
- void set(CodeArray *base) { base_ = base; }
- void defineSlabel(std::string label)
- {
- if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR);
- if (label == "@@") {
- SlabelDefList& defList = stateList_.front().defList;
- SlabelDefList::iterator i = defList.find("@f");
- if (i != defList.end()) {
- defList.erase(i);
- label = "@b";
- } else {
- i = defList.find("@b");
- if (i != defList.end()) {
- defList.erase(i);
- }
- label = "@f";
- }
- }
- SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
- define_inner(st.defList, st.undefList, label, base_->getSize());
- }
- void defineClabel(Label& label)
- {
- define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize());
- label.mgr = this;
- labelPtrList_.insert(&label);
- }
- void assign(Label& dst, const Label& src)
- {
- ClabelDefList::const_iterator i = clabelDefList_.find(src.id);
- if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L);
- define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset);
- dst.mgr = this;
- labelPtrList_.insert(&dst);
- }
- bool getOffset(size_t *offset, std::string& label) const
- {
- const SlabelDefList& defList = stateList_.front().defList;
- if (label == "@b") {
- if (defList.find("@f") != defList.end()) {
- label = "@f";
- } else if (defList.find("@b") == defList.end()) {
- throw Error(ERR_LABEL_IS_NOT_FOUND);
- }
- } else if (label == "@f") {
- if (defList.find("@f") != defList.end()) {
- label = "@b";
- }
- }
- const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
- return getOffset_inner(st.defList, offset, label);
- }
- bool getOffset(size_t *offset, const Label& label) const
- {
- return getOffset_inner(clabelDefList_, offset, getId(label));
- }
- void addUndefinedLabel(const std::string& label, const JmpLabel& jmp)
- {
- SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
- st.undefList.insert(SlabelUndefList::value_type(label, jmp));
- }
- void addUndefinedLabel(const Label& label, const JmpLabel& jmp)
- {
- clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp));
- }
- bool hasUndefSlabel() const
- {
- for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) {
- if (hasUndefinedLabel_inner(i->undefList)) return true;
- }
- return false;
- }
- bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); }
- const uint8 *getCode() const { return base_->getCode(); }
- bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); }
-};
-
-inline Label::Label(const Label& rhs)
-{
- id = rhs.id;
- mgr = rhs.mgr;
- if (mgr) mgr->incRefCount(id, this);
-}
-inline Label& Label::operator=(const Label& rhs)
-{
- if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L);
- id = rhs.id;
- mgr = rhs.mgr;
- if (mgr) mgr->incRefCount(id, this);
- return *this;
-}
-inline Label::~Label()
-{
- if (id && mgr) mgr->decRefCount(id, this);
-}
-inline const uint8* Label::getAddress() const
-{
- if (mgr == 0 || !mgr->isReady()) return 0;
- size_t offset;
- if (!mgr->getOffset(&offset, *this)) return 0;
- return mgr->getCode() + offset;
-}
-
-class CodeGenerator : public CodeArray {
-public:
- enum LabelType {
- T_SHORT,
- T_NEAR,
- T_AUTO // T_SHORT if possible
- };
-private:
- CodeGenerator operator=(const CodeGenerator&); // don't call
-#ifdef XBYAK64
- enum { i32e = 32 | 64, BIT = 64 };
- static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788;
- typedef Reg64 NativeReg;
-#else
- enum { i32e = 32, BIT = 32 };
- static const size_t dummyAddr = 0x12345678;
- typedef Reg32 NativeReg;
-#endif
- // (XMM, XMM|MEM)
- static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isXMM() && (op2.isXMM() || op2.isMEM());
- }
- // (MMX, MMX|MEM) or (XMM, XMM|MEM)
- static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2)
- {
- return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2);
- }
- // (XMM, MMX|MEM)
- static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isXMM() && (op2.isMMX() || op2.isMEM());
- }
- // (MMX, XMM|MEM)
- static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isMMX() && (op2.isXMM() || op2.isMEM());
- }
- // (XMM, REG32|MEM)
- static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM());
- }
- // (REG32, XMM|MEM)
- static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM());
- }
- // (REG32, REG32|MEM)
- static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2)
- {
- return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM());
- }
- void rex(const Operand& op1, const Operand& op2 = Operand())
- {
- uint8 rex = 0;
- const Operand *p1 = &op1, *p2 = &op2;
- if (p1->isMEM()) std::swap(p1, p2);
- if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION);
- if (p2->isMEM()) {
- const Address& addr = p2->getAddress();
- if (BIT == 64 && addr.is32bit()) db(0x67);
- rex = addr.getRex() | p1->getReg().getRex();
- } else {
- // ModRM(reg, base);
- rex = op2.getReg().getRex(op1.getReg());
- }
- // except movsx(16bit, 32/64bit)
- if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66);
- if (rex) db(rex);
- }
- enum AVXtype {
- // low 3 bit
- T_N1 = 1,
- T_N2 = 2,
- T_N4 = 3,
- T_N8 = 4,
- T_N16 = 5,
- T_N32 = 6,
- T_NX_MASK = 7,
- //
- T_N_VL = 1 << 3, // N * (1, 2, 4) for VL
- T_DUP = 1 << 4, // N = (8, 32, 64)
- T_66 = 1 << 5,
- T_F3 = 1 << 6,
- T_F2 = 1 << 7,
- T_0F = 1 << 8,
- T_0F38 = 1 << 9,
- T_0F3A = 1 << 10,
- T_L0 = 1 << 11,
- T_L1 = 1 << 12,
- T_W0 = 1 << 13,
- T_W1 = 1 << 14,
- T_EW0 = 1 << 15,
- T_EW1 = 1 << 16,
- T_YMM = 1 << 17, // support YMM, ZMM
- T_EVEX = 1 << 18,
- T_ER_X = 1 << 19, // xmm{er}
- T_ER_Y = 1 << 20, // ymm{er}
- T_ER_Z = 1 << 21, // zmm{er}
- T_SAE_X = 1 << 22, // xmm{sae}
- T_SAE_Y = 1 << 23, // ymm{sae}
- T_SAE_Z = 1 << 24, // zmm{sae}
- T_MUST_EVEX = 1 << 25, // contains T_EVEX
- T_B32 = 1 << 26, // m32bcst
- T_B64 = 1 << 27, // m64bcst
- T_M_K = 1 << 28, // mem{k}
- T_VSIB = 1 << 29,
- T_MEM_EVEX = 1 << 30, // use evex if mem
- T_XXX
- };
- void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false)
- {
- int w = (type & T_W1) ? 1 : 0;
- bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM();
- bool r = reg.isExtIdx();
- bool b = base.isExtIdx();
- int idx = v ? v->getIdx() : 0;
- if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION);
- uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0;
- uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp;
- if (!b && !x && !w && (type & T_0F)) {
- db(0xC5); db((r ? 0 : 0x80) | vvvv);
- } else {
- uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0;
- db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv);
- }
- db(code);
- }
- void verifySAE(const Reg& r, int type) const
- {
- if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return;
- throw Error(ERR_SAE_IS_INVALID);
- }
- void verifyER(const Reg& r, int type) const
- {
- if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return;
- throw Error(ERR_ER_IS_INVALID);
- }
-	// if two or three of (a, b, c) are non-zero, report err
- int verifyDuplicate(int a, int b, int c, int err)
- {
- int v = a | b | c;
- if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err);
- return v;
- }
- int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false)
- {
- if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID);
- int w = (type & T_EW1) ? 1 : 0;
- uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0;
- uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0;
-
- int idx = v ? v->getIdx() : 0;
- uint32 vvvv = ~idx;
-
- bool R = !reg.isExtIdx();
- bool X = x ? false : !base.isExtIdx2();
- bool B = !base.isExtIdx();
- bool Rp = !reg.isExtIdx2();
- int LL;
- int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET);
- int disp8N = 1;
- if (rounding) {
- if (rounding == EvexModifierRounding::T_SAE) {
- verifySAE(base, type); LL = 0;
- } else {
- verifyER(base, type); LL = rounding - 1;
- }
- b = true;
- } else {
- if (v) VL = (std::max)(VL, v->getBit());
- VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL);
- LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0;
- if (b) {
- disp8N = (type & T_B32) ? 4 : 8;
- } else if (type & T_DUP) {
- disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64;
- } else {
- if ((type & (T_NX_MASK | T_N_VL)) == 0) {
- type |= T_N16 | T_N_VL; // default
- }
- int low = type & T_NX_MASK;
- if (low > 0) {
- disp8N = 1 << (low - 1);
- if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1);
- }
- }
- }
- bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx);
- bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false);
- if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET);
- db(0x62);
- db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3));
- db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3));
- db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 8 : 0) | (aaa & 7));
- db(code);
- return disp8N;
- }
- void setModRM(int mod, int r1, int r2)
- {
- db(static_cast<uint8>((mod << 6) | ((r1 & 7) << 3) | (r2 & 7)));
- }
- void setSIB(const RegExp& e, int reg, int disp8N = 0)
- {
- size_t disp64 = e.getDisp();
-#ifdef XBYAK64
- size_t high = disp64 >> 32;
- if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG);
-#endif
- uint32 disp = static_cast<uint32>(disp64);
- const Reg& base = e.getBase();
- const Reg& index = e.getIndex();
- const int baseIdx = base.getIdx();
- const int baseBit = base.getBit();
- const int indexBit = index.getBit();
- enum {
- mod00 = 0, mod01 = 1, mod10 = 2
- };
- int mod = mod10; // disp32
- if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) {
- mod = mod00;
- } else {
- if (disp8N == 0) {
- if (inner::IsInDisp8(disp)) {
- mod = mod01;
- }
- } else {
-	// disp must be cast to signed
- uint32 t = static_cast<uint32>(static_cast<int>(disp) / disp8N);
- if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) {
- disp = t;
- mod = mod01;
- }
- }
- }
- const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP;
- /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */
- bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP;
-#ifdef XBYAK64
- if (!baseBit && !indexBit) hasSIB = true;
-#endif
- if (hasSIB) {
- setModRM(mod, reg, Operand::ESP);
- /* SIB = [2:3:3] = [SS:index:base(=rm)] */
- const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP;
- const int scale = e.getScale();
- const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0;
- setModRM(SS, idx, newBaseIdx);
- } else {
- setModRM(mod, reg, newBaseIdx);
- }
- if (mod == mod01) {
- db(disp);
- } else if (mod == mod10 || (mod == mod00 && !baseBit)) {
- dd(disp);
- }
- }
- LabelManager labelMgr_;
- bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; }
- void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE)
- {
- rex(reg2, reg1);
- db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2);
- setModRM(3, reg1.getIdx(), reg2.getIdx());
- }
- void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0)
- {
- if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
- rex(addr, reg);
- db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2);
- opAddr(addr, reg.getIdx(), immSize);
- }
- void opMIB(const Address& addr, const Reg& reg, int code0, int code1)
- {
- if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
- if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS);
- if (BIT == 64 && addr.is32bit()) db(0x67);
- const RegExp& regExp = addr.getRegExp(false);
- uint8 rex = regExp.getRex();
- if (rex) db(rex);
- db(code0); db(code1);
- setSIB(regExp, reg.getIdx());
- }
- void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref)
- {
- const int shortJmpSize = 2;
- const int longHeaderSize = longPref ? 2 : 1;
- const int longJmpSize = longHeaderSize + 4;
- if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) {
- db(shortCode); db(disp - shortJmpSize);
- } else {
- if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR);
- if (longPref) db(longPref);
- db(longCode); dd(disp - longJmpSize);
- }
- }
- template<class T>
- void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref)
- {
- if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */
- size_t offset = 0;
- if (labelMgr_.getOffset(&offset, label)) { /* label exists */
- makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref);
- } else {
- int jmpSize = 0;
- if (type == T_NEAR) {
- jmpSize = 4;
- if (longPref) db(longPref);
- db(longCode); dd(0);
- } else {
- jmpSize = 1;
- db(shortCode); db(0);
- }
- JmpLabel jmp(size_, jmpSize, inner::LasIs);
- labelMgr_.addUndefinedLabel(label, jmp);
- }
- }
- void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0)
- {
- if (isAutoGrow()) {
- if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW);
- if (size_ + 16 >= maxSize_) growMemory();
- if (longPref) db(longPref);
- db(longCode);
- dd(0);
- save(size_ - 4, size_t(addr) - size_, 4, inner::Labs);
- } else {
- makeJmp(inner::VerifyInInt32(reinterpret_cast<const uint8*>(addr) - getCurr()), type, shortCode, longCode, longPref);
- }
-
- }
- // reg is reg field of ModRM
- // immSize is the size for immediate value
- // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement
- void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false)
- {
- if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING);
- if (addr.getMode() == Address::M_ModRM) {
- setSIB(addr.getRegExp(), reg, disp8N);
- } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) {
- setModRM(0, reg, 5);
- if (addr.getLabel()) { // [rip + Label]
- putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize);
- } else {
- size_t disp = addr.getDisp();
- if (addr.getMode() == Address::M_ripAddr) {
- if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW);
- disp -= (size_t)getCurr() + 4 + immSize;
- }
- dd(inner::VerifyInInt32(disp));
- }
- }
- }
- /* preCode is for SSSE3/SSE4 */
- void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE)
- {
- if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION);
- if (pref != NONE) db(pref);
- if (op.isMEM()) {
- opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0);
- } else {
- opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code);
- }
- if (imm8 != NONE) db(imm8);
- }
- void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext)
- {
- if (mmx.isXMM()) db(0x66);
- opModR(Reg32(ext), mmx, 0x0F, code);
- db(imm8);
- }
- void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE)
- {
- opGen(mmx, op, code, mmx.isXMM() ? pref : NONE, isXMMorMMX_MEM, imm8, preCode);
- }
- void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref)
- {
- if (pref != NONE) db(pref);
- if (op1.isXMM() && op2.isMEM()) {
- opModM(op2.getAddress(), op1.getReg(), 0x0F, code);
- } else if (op1.isMEM() && op2.isXMM()) {
- opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1);
- } else {
- throw Error(ERR_BAD_COMBINATION);
- }
- }
- void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false)
- {
- if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */
- if (mmx.isXMM()) db(0x66);
- opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm);
- } else {
- opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A);
- }
- }
- void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0)
- {
- int opBit = op.getBit();
- if (disableRex && opBit == 64) opBit = 32;
- if (op.isREG(bit)) {
- opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2);
- } else if (op.isMEM()) {
- opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize);
- } else {
- throw Error(ERR_BAD_COMBINATION);
- }
- }
- void opShift(const Operand& op, int imm, int ext)
- {
- verifyMemHasSize(op);
- opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0);
- if (imm != 1) db(imm);
- }
- void opShift(const Operand& op, const Reg8& _cl, int ext)
- {
- if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION);
- opR_ModM(op, 0, ext, 0xD2);
- }
- void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0)
- {
- if (condR) {
- opModR(op1.getReg(), op2.getReg(), code0, code1, code2);
- } else if (condM) {
- opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize);
- } else {
- throw Error(ERR_BAD_COMBINATION);
- }
- }
- void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0)
- {
- if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION);
- opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1);
- if (!_cl) db(imm);
- }
- // (REG, REG|MEM), (MEM, REG)
- void opRM_RM(const Operand& op1, const Operand& op2, int code)
- {
- if (op1.isREG() && op2.isMEM()) {
- opModM(op2.getAddress(), op1.getReg(), code | 2);
- } else {
- opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code);
- }
- }
- // (REG|MEM, IMM)
- void opRM_I(const Operand& op, uint32 imm, int code, int ext)
- {
- verifyMemHasSize(op);
- uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32;
- if (op.isBit(8)) immBit = 8;
- if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG);
- if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */
- if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al
- rex(op);
- db(code | 4 | (immBit == 8 ? 0 : 1));
- } else {
- int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0;
- opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8);
- }
- db(imm, immBit / 8);
- }
- void opIncDec(const Operand& op, int code, int ext)
- {
- verifyMemHasSize(op);
-#ifndef XBYAK64
- if (op.isREG() && !op.isBit(8)) {
- rex(op); db(code | op.getIdx());
- return;
- }
-#endif
- code = 0xFE;
- if (op.isREG()) {
- opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code);
- } else {
- opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code);
- }
- }
- void opPushPop(const Operand& op, int code, int ext, int alt)
- {
- int bit = op.getBit();
- if (bit == 16 || bit == BIT) {
- if (bit == 16) db(0x66);
- if (op.isREG()) {
- if (op.getReg().getIdx() >= 8) db(0x41);
- db(alt | (op.getIdx() & 7));
- return;
- }
- if (op.isMEM()) {
- opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code);
- return;
- }
- }
- throw Error(ERR_BAD_COMBINATION);
- }
- void verifyMemHasSize(const Operand& op) const
- {
- if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED);
- }
- /*
- mov(r, imm) = db(imm, mov_imm(r, imm))
- */
- int mov_imm(const Reg& reg, size_t imm)
- {
- int bit = reg.getBit();
- const int idx = reg.getIdx();
- int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3);
- if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) {
- rex(Reg32(idx));
- bit = 32;
- } else {
- rex(reg);
- if (bit == 64 && inner::IsInInt32(imm)) {
- db(0xC7);
- code = 0xC0;
- bit = 32;
- }
- }
- db(code | (idx & 7));
- return bit / 8;
- }
- template<class T>
- void putL_inner(T& label, bool relative = false, size_t disp = 0)
- {
- const int jmpSize = relative ? 4 : (int)sizeof(size_t);
- if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory();
- size_t offset = 0;
- if (labelMgr_.getOffset(&offset, label)) {
- if (relative) {
- db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize);
- } else if (isAutoGrow()) {
- db(uint64(0), jmpSize);
- save(size_ - jmpSize, offset, jmpSize, inner::LaddTop);
- } else {
- db(size_t(top_) + offset, jmpSize);
- }
- return;
- }
- db(uint64(0), jmpSize);
- JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp);
- labelMgr_.addUndefinedLabel(label, jmp);
- }
- void opMovxx(const Reg& reg, const Operand& op, uint8 code)
- {
- if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION);
- int w = op.isBit(16);
-#ifdef XBYAK64
- if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION);
-#endif
- bool cond = reg.isREG() && (reg.getBit() > op.getBit());
- opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w);
- }
- void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext)
- {
- if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
- uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0;
- if (!code) throw Error(ERR_BAD_MEM_SIZE);
- if (m64ext && addr.isBit(64)) ext = m64ext;
-
- rex(addr, st0);
- db(code);
- opAddr(addr, ext);
- }
- // use code1 if reg1 == st0
- // use code2 if reg1 != st0 && reg2 == st0
- void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2)
- {
- uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0;
- if (!code) throw Error(ERR_BAD_ST_COMBINATION);
- db(uint8(code >> 8));
- db(uint8(code | (reg1.getIdx() | reg2.getIdx())));
- }
- void opFpu(const Fpu& reg, uint8 code1, uint8 code2)
- {
- db(code1); db(code2 | reg.getIdx());
- }
- void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE)
- {
- if (op2.isMEM()) {
- const Address& addr = op2.getAddress();
- const RegExp& regExp = addr.getRegExp();
- const Reg& base = regExp.getBase();
- const Reg& index = regExp.getIndex();
- if (BIT == 64 && addr.is32bit()) db(0x67);
- int disp8N = 0;
- bool x = index.isExtIdx();
- if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) {
- int aaa = addr.getOpmaskIdx();
- if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY);
- bool b = false;
- if (addr.isBroadcast()) {
- if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST);
- b = true;
- }
- int VL = regExp.isVsib() ? index.getBit() : 0;
- disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2());
- } else {
- vex(r, base, p1, type, code, x);
- }
- opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0);
- } else {
- const Reg& base = op2.getReg();
- if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) {
- evex(r, base, p1, type, code);
- } else {
- vex(r, base, p1, type, code);
- }
- setModRM(3, r.getIdx(), base.getIdx());
- }
- if (imm8 != NONE) db(imm8);
- }
- // (r, r, r/m) if isR_R_RM
- // (r, r/m, r)
- void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE)
- {
- const Operand *p1 = &op1;
- const Operand *p2 = &op2;
- if (!isR_R_RM) std::swap(p1, p2);
- const unsigned int bit = r.getBit();
- if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION);
- type |= (bit == 64) ? T_W1 : T_W0;
- opVex(r, p1, *p2, type, code, imm8);
- }
- void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE)
- {
- const Xmm *x2 = static_cast<const Xmm*>(&op1);
- const Operand *op = &op2;
- if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1)
- x2 = &x1;
- op = &op1;
- }
- // (x1, x2, op)
- if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION);
- opVex(x1, x2, *op, type, code0, imm8);
- }
- void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE)
- {
- if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION);
- opVex(k, &x2, op3, type, code0, imm8);
- }
- // (x, x/m), (y, x/m256), (z, y/m)
- void checkCvt1(const Operand& x, const Operand& op) const
- {
- if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION);
- }
- // (x, x/m), (x, y/m256), (y, z/m)
- void checkCvt2(const Xmm& x, const Operand& op) const
- {
- if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION);
- }
- void opCvt2(const Xmm& x, const Operand& op, int type, int code)
- {
- checkCvt2(x, op);
- Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM;
- opVex(x.copyAndSetKind(kind), &xm0, op, type, code);
- }
- void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code)
- {
- if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER);
- Xmm x(op.getIdx());
- const Operand *p = op.isREG() ? &x : &op;
- opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code);
- }
- const Xmm& cvtIdx0(const Operand& x) const
- {
- return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0;
- }
- // support (x, x/m, imm), (y, y/m, imm)
- void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE)
- {
- opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8);
- }
- // QQQ:need to refactor
- void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1)
- {
- if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER);
- bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM());
- if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION);
- if (is16bit) db(0x66);
- db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1);
- }
- void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode)
- {
- const RegExp& regExp = addr.getRegExp();
- if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING);
- const int y_vx_y = 0;
- const int y_vy_y = 1;
-// const int x_vy_x = 2;
- const bool isAddrYMM = regExp.getIndex().getBit() == 256;
- if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) {
- bool isOK = false;
- if (mode == y_vx_y) {
- isOK = x1.isYMM() && !isAddrYMM && x2.isYMM();
- } else if (mode == y_vy_y) {
- isOK = x1.isYMM() && isAddrYMM && x2.isYMM();
- } else { // x_vy_x
- isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM();
- }
- if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING);
- }
- opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? Ymm(x2.getIdx()) : x2, addr, type, code);
- }
- enum {
- xx_yy_zz = 0,
- xx_yx_zy = 1,
- xx_xy_yz = 2
- };
- void checkGather2(const Xmm& x1, const Reg& x2, int mode) const
- {
- if (x1.isXMM() && x2.isXMM()) return;
- switch (mode) {
- case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return;
- break;
- case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return;
- break;
- case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return;
- break;
- }
- throw Error(ERR_BAD_VSIB_ADDRESSING);
- }
- void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode)
- {
- if (x.hasZero()) throw Error(ERR_INVALID_ZERO);
- checkGather2(x, addr.getRegExp().getIndex(), mode);
- opVex(x, 0, addr, type, code);
- }
- /*
- xx_xy_yz ; mode = true
- xx_xy_xz ; mode = false
- */
- void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode)
- {
- if (mode) {
- if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION);
- } else {
- if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION);
- }
- opVex(x, 0, op, type, code);
- }
- void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind)
- {
- if (addr.hasZero()) throw Error(ERR_INVALID_ZERO);
- if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING);
- opVex(x, 0, addr, type, code);
- }
-public:
- unsigned int getVersion() const { return VERSION; }
- using CodeArray::db;
- const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7;
- const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7;
- const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
- const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
- const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7;
- const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7;
- const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7;
- const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi;
- const Reg16 ax, cx, dx, bx, sp, bp, si, di;
- const Reg8 al, cl, dl, bl, ah, ch, dh, bh;
-	const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is the same as NASM's oword
- const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b}
- const Fpu st0, st1, st2, st3, st4, st5, st6, st7;
- const Opmask k0, k1, k2, k3, k4, k5, k6, k7;
- const BoundsReg bnd0, bnd1, bnd2, bnd3;
- const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae}
- const EvexModifierZero T_z; // {z}
-#ifdef XBYAK64
- const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15;
- const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d;
- const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w;
- const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b;
- const Reg8 spl, bpl, sil, dil;
- const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15;
- const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23;
- const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31;
- const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
- const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23;
- const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31;
- const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15;
- const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23;
- const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
- const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience
- const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23;
- const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31;
- const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15;
- const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23;
- const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31;
- const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15;
- const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23;
- const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31;
- const RegRip rip;
-#endif
-#ifndef XBYAK_DISABLE_SEGMENT
- const Segment es, cs, ss, ds, fs, gs;
-#endif
- void L(const std::string& label) { labelMgr_.defineSlabel(label); }
- void L(Label& label) { labelMgr_.defineClabel(label); }
- Label L() { Label label; L(label); return label; }
- void inLocalLabel() { labelMgr_.enterLocal(); }
- void outLocalLabel() { labelMgr_.leaveLocal(); }
- /*
- assign src to dst
- require
-			dst : must not be used by L()
- src : used by L()
- */
- void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); }
- /*
- put address of label to buffer
-		@note the stored size is 4 bytes (32-bit) or 8 bytes (64-bit)
- */
- void putL(std::string label) { putL_inner(label); }
- void putL(const Label& label) { putL_inner(label); }
-
- void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); }
- void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); }
- void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); }
- void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); }
- void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); }
-
- void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); }
- // call(string label), not const std::string&
- void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); }
- void call(const char *label) { call(std::string(label)); }
- void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); }
- // call(function pointer)
-#ifdef XBYAK_VARIADIC_TEMPLATE
- template<class Ret, class... Params>
- void call(Ret(*func)(Params...)) { call(reinterpret_cast<const void*>(func)); }
-#endif
- void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); }
-
- void test(const Operand& op, const Reg& reg)
- {
- opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84);
- }
- void test(const Operand& op, uint32 imm)
- {
- verifyMemHasSize(op);
- int immSize = (std::min)(op.getBit() / 8, 4U);
- if (op.isREG() && op.getIdx() == 0) { // al, ax, eax
- rex(op);
- db(0xA8 | (op.isBit(8) ? 0 : 1));
- } else {
- opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize);
- }
- db(imm, immSize);
- }
- void imul(const Reg& reg, const Operand& op)
- {
- opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF);
- }
- void imul(const Reg& reg, const Operand& op, int imm)
- {
- int s = inner::IsInDisp8(imm) ? 1 : 0;
- int immSize = s ? 1 : reg.isREG(16) ? 2 : 4;
- opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize);
- db(imm, immSize);
- }
- void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); }
- void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); }
- void push(const AddressFrame& af, uint32 imm)
- {
- if (af.bit_ == 8 && inner::IsInDisp8(imm)) {
- db(0x6A); db(imm);
- } else if (af.bit_ == 16 && isInDisp16(imm)) {
- db(0x66); db(0x68); dw(imm);
- } else {
- db(0x68); dd(imm);
- }
- }
- /* use "push(word, 4)" if you want "push word 4" */
- void push(uint32 imm)
- {
- if (inner::IsInDisp8(imm)) {
- push(byte, imm);
- } else {
- push(dword, imm);
- }
- }
- void mov(const Operand& reg1, const Operand& reg2)
- {
- const Reg *reg = 0;
- const Address *addr = 0;
- uint8 code = 0;
- if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp]
- reg = &reg1.getReg();
- addr= &reg2.getAddress();
- code = 0xA0;
- } else
- if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al
- reg = &reg2.getReg();
- addr= &reg1.getAddress();
- code = 0xA2;
- }
-#ifdef XBYAK64
- if (addr && addr->is64bitDisp()) {
- if (code) {
- rex(*reg);
- db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3);
- db(addr->getDisp(), 8);
- } else {
- throw Error(ERR_BAD_COMBINATION);
- }
- } else
-#else
- if (code && addr->isOnlyDisp()) {
- rex(*reg, *addr);
- db(code | (reg->isBit(8) ? 0 : 1));
- dd(static_cast<uint32>(addr->getDisp()));
- } else
-#endif
- {
- opRM_RM(reg1, reg2, 0x88);
- }
- }
- void mov(const Operand& op, size_t imm)
- {
- if (op.isREG()) {
- const int size = mov_imm(op.getReg(), imm);
- db(imm, size);
- } else if (op.isMEM()) {
- verifyMemHasSize(op);
- int immSize = op.getBit() / 8;
- if (immSize <= 4) {
- sint64 s = sint64(imm) >> (immSize * 8);
- if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG);
- } else {
- if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG);
- immSize = 4;
- }
- opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize);
- db(static_cast<uint32>(imm), immSize);
- } else {
- throw Error(ERR_BAD_COMBINATION);
- }
- }
- void mov(const NativeReg& reg, const char *label) // can't use std::string
- {
- if (label == 0) {
- mov(static_cast<const Operand&>(reg), 0); // call imm
- return;
- }
- mov_imm(reg, dummyAddr);
- putL(label);
- }
- void mov(const NativeReg& reg, const Label& label)
- {
- mov_imm(reg, dummyAddr);
- putL(label);
- }
- void xchg(const Operand& op1, const Operand& op2)
- {
- const Operand *p1 = &op1, *p2 = &op2;
- if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) {
- p1 = &op2; p2 = &op1;
- }
- if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION);
- if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0)
-#ifdef XBYAK64
- && (p2->getIdx() != 0 || !p1->isREG(32))
-#endif
- ) {
- rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7));
- return;
- }
- opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 0 : 1));
- }
-
-#ifndef XBYAK_DISABLE_SEGMENT
- void push(const Segment& seg)
- {
- switch (seg.getIdx()) {
- case Segment::es: db(0x06); break;
- case Segment::cs: db(0x0E); break;
- case Segment::ss: db(0x16); break;
- case Segment::ds: db(0x1E); break;
- case Segment::fs: db(0x0F); db(0xA0); break;
- case Segment::gs: db(0x0F); db(0xA8); break;
- default:
- assert(0);
- }
- }
- void pop(const Segment& seg)
- {
- switch (seg.getIdx()) {
- case Segment::es: db(0x07); break;
- case Segment::cs: throw Error(ERR_BAD_COMBINATION);
- case Segment::ss: db(0x17); break;
- case Segment::ds: db(0x1F); break;
- case Segment::fs: db(0x0F); db(0xA1); break;
- case Segment::gs: db(0x0F); db(0xA9); break;
- default:
- assert(0);
- }
- }
- void putSeg(const Segment& seg)
- {
- switch (seg.getIdx()) {
- case Segment::es: db(0x2E); break;
- case Segment::cs: db(0x36); break;
- case Segment::ss: db(0x3E); break;
- case Segment::ds: db(0x26); break;
- case Segment::fs: db(0x64); break;
- case Segment::gs: db(0x65); break;
- default:
- assert(0);
- }
- }
- void mov(const Operand& op, const Segment& seg)
- {
- opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C);
- }
- void mov(const Segment& seg, const Operand& op)
- {
- opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast<const Operand&>(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E);
- }
-#endif
-
- enum { NONE = 256 };
- // constructor
- CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0)
- : CodeArray(maxSize, userPtr, allocator)
- , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7)
- , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7)
- , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7)
- , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7)
- // for my convenience
- , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7)
- , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7)
- , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7)
-
- , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI)
- , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI)
- , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH)
- , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512)
- , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true)
- , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7)
- , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7)
- , bnd0(0), bnd1(1), bnd2(2), bnd3(3)
- , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE)
- , T_z()
-#ifdef XBYAK64
- , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15)
- , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15)
- , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15)
- , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15)
- , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true)
- , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15)
- , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23)
- , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31)
- , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15)
- , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23)
- , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31)
- , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15)
- , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23)
- , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31)
- // for my convenience
- , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15)
- , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23)
- , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31)
- , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15)
- , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23)
- , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31)
- , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15)
- , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23)
- , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31)
- , rip()
-#endif
-#ifndef XBYAK_DISABLE_SEGMENT
- , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs)
-#endif
- {
- labelMgr_.set(this);
- }
- void reset()
- {
- resetSize();
- labelMgr_.reset();
- labelMgr_.set(this);
- }
- bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); }
- /*
- MUST call ready() to complete generating code if you use AutoGrow mode.
- It is not necessary for the other mode if hasUndefinedLabel() is true.
- */
- void ready(ProtectMode mode = PROTECT_RWE)
- {
- if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND);
- if (isAutoGrow()) {
- calcJmpAddress();
- if (useProtect()) setProtectMode(mode);
- }
- }
- // set read/exec
- void readyRE() { return ready(PROTECT_RE); }
-#ifdef XBYAK_TEST
- void dump(bool doClear = true)
- {
- CodeArray::dump();
- if (doClear) size_ = 0;
- }
-#endif
-
-#ifdef XBYAK_UNDEF_JNL
- #undef jnl
-#endif
-
- /*
- use single byte nop if useMultiByteNop = false
- */
- void nop(size_t size = 1, bool useMultiByteNop = true)
- {
- if (!useMultiByteNop) {
- for (size_t i = 0; i < size; i++) {
- db(0x90);
- }
- return;
- }
- /*
- Intel Architectures Software Developer's Manual Volume 2
- recommended multi-byte sequence of NOP instruction
- AMD and Intel seem to agree on the same sequences for up to 9 bytes:
- https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf
- */
- static const uint8 nopTbl[9][9] = {
- {0x90},
- {0x66, 0x90},
- {0x0F, 0x1F, 0x00},
- {0x0F, 0x1F, 0x40, 0x00},
- {0x0F, 0x1F, 0x44, 0x00, 0x00},
- {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00},
- {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00},
- {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00},
- {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00},
- };
- const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]);
- while (size > 0) {
- size_t len = (std::min)(n, size);
- const uint8 *seq = nopTbl[len - 1];
- db(seq, len);
- size -= len;
- }
- }
-
-#ifndef XBYAK_DONT_READ_LIST
-#include "xbyak_mnemonic.h"
- /*
- use single byte nop if useMultiByteNop = false
- */
- void align(size_t x = 16, bool useMultiByteNop = true)
- {
- if (x == 1) return;
- if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN);
- if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x);
- size_t remain = size_t(getCurr()) % x;
- if (remain) {
- nop(x - remain, useMultiByteNop);
- }
- }
-#endif
-};
-
-namespace util {
-static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7);
-static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7);
-static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7);
-static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7);
-static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI);
-static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI);
-static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH);
-static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512);
-static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true);
-static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7);
-static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7);
-static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3);
-static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE);
-static const EvexModifierZero T_z;
-#ifdef XBYAK64
-static const Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15);
-static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15);
-static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15);
-static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true);
-static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15);
-static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23);
-static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31);
-static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15);
-static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23);
-static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31);
-static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15);
-static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23);
-static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31);
-static const RegRip rip;
-#endif
-#ifndef XBYAK_DISABLE_SEGMENT
-static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs);
-#endif
-} // util
-
-#ifdef _MSC_VER
- #pragma warning(pop)
-#endif
-
-} // end of namespace
-
-#endif // XBYAK_XBYAK_H_
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h
deleted file mode 100644
index a22e5224c3..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h
+++ /dev/null
@@ -1,303 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*******************************************************************************
-* Copyright (c) 2007 MITSUNARI Shigeo
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* Redistributions of source code must retain the above copyright notice, this
-* list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* Neither the name of the copyright owner nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-*******************************************************************************/
-
-enum {
- B00000000= 0,
- B00000001= 1,
- B00000010= 2,
- B00000011= 3,
- B00000100= 4,
- B00000101= 5,
- B00000110= 6,
- B00000111= 7,
- B00001000= 8,
- B00001001= 9,
- B00001010= 10,
- B00001011= 11,
- B00001100= 12,
- B00001101= 13,
- B00001110= 14,
- B00001111= 15,
- B00010000= 16,
- B00010001= 17,
- B00010010= 18,
- B00010011= 19,
- B00010100= 20,
- B00010101= 21,
- B00010110= 22,
- B00010111= 23,
- B00011000= 24,
- B00011001= 25,
- B00011010= 26,
- B00011011= 27,
- B00011100= 28,
- B00011101= 29,
- B00011110= 30,
- B00011111= 31,
- B00100000= 32,
- B00100001= 33,
- B00100010= 34,
- B00100011= 35,
- B00100100= 36,
- B00100101= 37,
- B00100110= 38,
- B00100111= 39,
- B00101000= 40,
- B00101001= 41,
- B00101010= 42,
- B00101011= 43,
- B00101100= 44,
- B00101101= 45,
- B00101110= 46,
- B00101111= 47,
- B00110000= 48,
- B00110001= 49,
- B00110010= 50,
- B00110011= 51,
- B00110100= 52,
- B00110101= 53,
- B00110110= 54,
- B00110111= 55,
- B00111000= 56,
- B00111001= 57,
- B00111010= 58,
- B00111011= 59,
- B00111100= 60,
- B00111101= 61,
- B00111110= 62,
- B00111111= 63,
- B01000000= 64,
- B01000001= 65,
- B01000010= 66,
- B01000011= 67,
- B01000100= 68,
- B01000101= 69,
- B01000110= 70,
- B01000111= 71,
- B01001000= 72,
- B01001001= 73,
- B01001010= 74,
- B01001011= 75,
- B01001100= 76,
- B01001101= 77,
- B01001110= 78,
- B01001111= 79,
- B01010000= 80,
- B01010001= 81,
- B01010010= 82,
- B01010011= 83,
- B01010100= 84,
- B01010101= 85,
- B01010110= 86,
- B01010111= 87,
- B01011000= 88,
- B01011001= 89,
- B01011010= 90,
- B01011011= 91,
- B01011100= 92,
- B01011101= 93,
- B01011110= 94,
- B01011111= 95,
- B01100000= 96,
- B01100001= 97,
- B01100010= 98,
- B01100011= 99,
- B01100100= 100,
- B01100101= 101,
- B01100110= 102,
- B01100111= 103,
- B01101000= 104,
- B01101001= 105,
- B01101010= 106,
- B01101011= 107,
- B01101100= 108,
- B01101101= 109,
- B01101110= 110,
- B01101111= 111,
- B01110000= 112,
- B01110001= 113,
- B01110010= 114,
- B01110011= 115,
- B01110100= 116,
- B01110101= 117,
- B01110110= 118,
- B01110111= 119,
- B01111000= 120,
- B01111001= 121,
- B01111010= 122,
- B01111011= 123,
- B01111100= 124,
- B01111101= 125,
- B01111110= 126,
- B01111111= 127,
- B10000000= 128,
- B10000001= 129,
- B10000010= 130,
- B10000011= 131,
- B10000100= 132,
- B10000101= 133,
- B10000110= 134,
- B10000111= 135,
- B10001000= 136,
- B10001001= 137,
- B10001010= 138,
- B10001011= 139,
- B10001100= 140,
- B10001101= 141,
- B10001110= 142,
- B10001111= 143,
- B10010000= 144,
- B10010001= 145,
- B10010010= 146,
- B10010011= 147,
- B10010100= 148,
- B10010101= 149,
- B10010110= 150,
- B10010111= 151,
- B10011000= 152,
- B10011001= 153,
- B10011010= 154,
- B10011011= 155,
- B10011100= 156,
- B10011101= 157,
- B10011110= 158,
- B10011111= 159,
- B10100000= 160,
- B10100001= 161,
- B10100010= 162,
- B10100011= 163,
- B10100100= 164,
- B10100101= 165,
- B10100110= 166,
- B10100111= 167,
- B10101000= 168,
- B10101001= 169,
- B10101010= 170,
- B10101011= 171,
- B10101100= 172,
- B10101101= 173,
- B10101110= 174,
- B10101111= 175,
- B10110000= 176,
- B10110001= 177,
- B10110010= 178,
- B10110011= 179,
- B10110100= 180,
- B10110101= 181,
- B10110110= 182,
- B10110111= 183,
- B10111000= 184,
- B10111001= 185,
- B10111010= 186,
- B10111011= 187,
- B10111100= 188,
- B10111101= 189,
- B10111110= 190,
- B10111111= 191,
- B11000000= 192,
- B11000001= 193,
- B11000010= 194,
- B11000011= 195,
- B11000100= 196,
- B11000101= 197,
- B11000110= 198,
- B11000111= 199,
- B11001000= 200,
- B11001001= 201,
- B11001010= 202,
- B11001011= 203,
- B11001100= 204,
- B11001101= 205,
- B11001110= 206,
- B11001111= 207,
- B11010000= 208,
- B11010001= 209,
- B11010010= 210,
- B11010011= 211,
- B11010100= 212,
- B11010101= 213,
- B11010110= 214,
- B11010111= 215,
- B11011000= 216,
- B11011001= 217,
- B11011010= 218,
- B11011011= 219,
- B11011100= 220,
- B11011101= 221,
- B11011110= 222,
- B11011111= 223,
- B11100000= 224,
- B11100001= 225,
- B11100010= 226,
- B11100011= 227,
- B11100100= 228,
- B11100101= 229,
- B11100110= 230,
- B11100111= 231,
- B11101000= 232,
- B11101001= 233,
- B11101010= 234,
- B11101011= 235,
- B11101100= 236,
- B11101101= 237,
- B11101110= 238,
- B11101111= 239,
- B11110000= 240,
- B11110001= 241,
- B11110010= 242,
- B11110011= 243,
- B11110100= 244,
- B11110101= 245,
- B11110110= 246,
- B11110111= 247,
- B11111000= 248,
- B11111001= 249,
- B11111010= 250,
- B11111011= 251,
- B11111100= 252,
- B11111101= 253,
- B11111110= 254,
- B11111111= 255
-};
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h
deleted file mode 100644
index 28d2d222f9..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h
+++ /dev/null
@@ -1,2017 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*******************************************************************************
-* Copyright (c) 2007 MITSUNARI Shigeo
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* Redistributions of source code must retain the above copyright notice, this
-* list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* Neither the name of the copyright owner nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-*******************************************************************************/
-
-const char *getVersionString() const { return "5.76"; }
-void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); }
-void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); }
-void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); }
-void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); }
-void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); }
-void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); }
-void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); }
-void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); }
-void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); }
-void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); }
-void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); }
-void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); }
-void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); }
-void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); }
-void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); }
-void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); }
-void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); }
-void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); }
-void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); }
-void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); }
-void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); }
-void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); }
-void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); }
-void bnd() { db(0xF2); }
-void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); }
-void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); }
-void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); }
-void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); }
-void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); }
-void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); }
-void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); }
-void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); }
-void bsf(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); }
-void bsr(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); }
-void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); }
-void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); }
-void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); }
-void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); }
-void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); }
-void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); }
-void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); }
-void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); }
-void bts(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); }
-void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); }
-void cbw() { db(0x66); db(0x98); }
-void cdq() { db(0x99); }
-void clc() { db(0xF8); }
-void cld() { db(0xFC); }
-void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); }
-void cli() { db(0xFA); }
-void cmc() { db(0xF5); }
-void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524
-void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
-void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
-void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524
-void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
-void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524
-void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524
-void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524
-void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524
-void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524
-void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524
-void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
-void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
-void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524
-void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
-void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524
-void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524
-void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524
-void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524
-void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524
-void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524
-void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524
-void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524
-void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524
-void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 0); }//-V524
-void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524
-void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524
-void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524
-void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524
-void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524
-void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); }
-void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); }
-void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); }
-void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); }
-void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); }
-void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); }
-void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); }
-void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); }
-void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); }
-void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); }
-void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); }
-void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); }
-void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); }
-void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); }
-void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); }
-void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); }
-void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); }
-void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); }
-void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); }
-void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); }
-void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); }
-void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); }
-void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); }
-void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); }
-void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); }
-void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); }
-void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); }
-void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); }
-void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); }
-void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); }
-void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); }
-void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); }
-void cmpsb() { db(0xA6); }
-void cmpsd() { db(0xA7); }
-void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); }
-void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); }
-void cmpsw() { db(0x66); db(0xA7); }
-void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); }
-void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); }
-void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); }
-void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); }
-void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); }
-void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); }
-void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); }
-void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); }
-void cpuid() { db(0x0F); db(0xA2); }
-void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); }
-void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); }
-void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); }
-void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); }
-void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); }
-void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); }
-void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); }
-void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); }
-void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); }
-void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); }
-void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); }
-void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); }
-void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); }
-void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); }
-void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); }
-void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); }
-void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); }
-void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); }
-void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); }
-void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); }
-void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); }
-void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); }
-void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); }
-void cwd() { db(0x66); db(0x99); }
-void cwde() { db(0x98); }
-void dec(const Operand& op) { opIncDec(op, 0x48, 1); }
-void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); }
-void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); }
-void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); }
-void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); }
-void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); }
-void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void emms() { db(0x0F); db(0x77); }
-void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); }
-void f2xm1() { db(0xD9); db(0xF0); }
-void fabs() { db(0xD9); db(0xE1); }
-void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); }
-void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); }
-void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); }
-void faddp() { db(0xDE); db(0xC1); }
-void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); }
-void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); }
-void fchs() { db(0xD9); db(0xE0); }
-void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); }
-void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); }
-void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); }
-void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); }
-void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); }
-void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); }
-void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); }
-void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); }
-void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); }
-void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); }
-void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); }
-void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); }
-void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); }
-void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); }
-void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); }
-void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); }
-void fcom() { db(0xD8); db(0xD1); }
-void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); }
-void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); }
-void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); }
-void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); }
-void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); }
-void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); }
-void fcomp() { db(0xD8); db(0xD9); }
-void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); }
-void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); }
-void fcompp() { db(0xDE); db(0xD9); }
-void fcos() { db(0xD9); db(0xFF); }
-void fdecstp() { db(0xD9); db(0xF6); }
-void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); }
-void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); }
-void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); }
-void fdivp() { db(0xDE); db(0xF9); }
-void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); }
-void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); }
-void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); }
-void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); }
-void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); }
-void fdivrp() { db(0xDE); db(0xF1); }
-void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); }
-void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); }
-void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); }
-void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); }
-void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); }
-void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); }
-void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); }
-void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); }
-void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); }
-void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); }
-void fincstp() { db(0xD9); db(0xF7); }
-void finit() { db(0x9B); db(0xDB); db(0xE3); }
-void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); }
-void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); }
-void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); }
-void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); }
-void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); }
-void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); }
-void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); }
-void fld1() { db(0xD9); db(0xE8); }
-void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); }
-void fldl2e() { db(0xD9); db(0xEA); }
-void fldl2t() { db(0xD9); db(0xE9); }
-void fldlg2() { db(0xD9); db(0xEC); }
-void fldln2() { db(0xD9); db(0xED); }
-void fldpi() { db(0xD9); db(0xEB); }
-void fldz() { db(0xD9); db(0xEE); }
-void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); }
-void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); }
-void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); }
-void fmulp() { db(0xDE); db(0xC9); }
-void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); }
-void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); }
-void fninit() { db(0xDB); db(0xE3); }
-void fnop() { db(0xD9); db(0xD0); }
-void fpatan() { db(0xD9); db(0xF3); }
-void fprem() { db(0xD9); db(0xF8); }
-void fprem1() { db(0xD9); db(0xF5); }
-void fptan() { db(0xD9); db(0xF2); }
-void frndint() { db(0xD9); db(0xFC); }
-void fscale() { db(0xD9); db(0xFD); }
-void fsin() { db(0xD9); db(0xFE); }
-void fsincos() { db(0xD9); db(0xFB); }
-void fsqrt() { db(0xD9); db(0xFA); }
-void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); }
-void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); }
-void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); }
-void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); }
-void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); }
-void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); }
-void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); }
-void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); }
-void fsubp() { db(0xDE); db(0xE9); }
-void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); }
-void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); }
-void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); }
-void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); }
-void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); }
-void fsubrp() { db(0xDE); db(0xE1); }
-void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); }
-void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); }
-void ftst() { db(0xD9); db(0xE4); }
-void fucom() { db(0xDD); db(0xE1); }
-void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); }
-void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); }
-void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); }
-void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); }
-void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); }
-void fucomp() { db(0xDD); db(0xE9); }
-void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); }
-void fucompp() { db(0xDA); db(0xE9); }
-void fwait() { db(0x9B); }
-void fxam() { db(0xD9); db(0xE5); }
-void fxch() { db(0xD9); db(0xC9); }
-void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); }
-void fxtract() { db(0xD9); db(0xF4); }
-void fyl2x() { db(0xD9); db(0xF1); }
-void fyl2xp1() { db(0xD9); db(0xF9); }
-void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); }
-void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); }
-void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); }
-void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); }
-void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); }
-void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); }
-void inc(const Operand& op) { opIncDec(op, 0x40, 0); }
-void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
-void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524
-void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524
-void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
-void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524
-void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
-void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524
-void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
-void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
-void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524
-void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524
-void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
-void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524
-void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
-void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
-void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524
-void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524
-void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
-void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
-void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524
-void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524
-void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
-void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
-void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524
-void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524
-void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
-void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
-void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524
-void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524
-void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
-void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
-void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524
-void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524
-void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
-void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
-void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524
-void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524
-void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
-void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524
-void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
-void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
-void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524
-void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
-void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
-void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524
-void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524
-void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
-void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524
-void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
-void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
-void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
-void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524
-void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524
-void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
-void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
-void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524
-void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524
-void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
-void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
-void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524
-void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524
-void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
-void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
-void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524
-void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524
-void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
-void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
-void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524
-void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524
-void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
-void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524
-void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524
-void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524
-void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524
-void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
-void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524
-void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524
-void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
-void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524
-void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524
-void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524
-void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524
-void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
-void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524
-void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524
-void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
-void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524
-void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524
-void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524
-void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524
-void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
-void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524
-void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524
-void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
-void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
-void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524
-void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524
-void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
-void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
-void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524
-void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524
-void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
-void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524
-void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524
-void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524
-void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524
-void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
-void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524
-void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524
-void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
-void lahf() { db(0x9F); }
-void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); }
-void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); }
-void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); }
-void lfence() { db(0x0F); db(0xAE); db(0xE8); }
-void lock() { db(0xF0); }
-void lzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); }
-void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); }
-void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); }
-void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); }
-void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); }
-void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); }
-void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); }
-void mfence() { db(0x0F); db(0xAE); db(0xF0); }
-void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); }
-void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); }
-void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); }
-void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); }
-void monitor() { db(0x0F); db(0x01); db(0xC8); }
-void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); }
-void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); }
-void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); }
-void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); }
-void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); }
-void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); }
-void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); }
-void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); }
-void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); }
-void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); }
-void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); }
-void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); }
-void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); }
-void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); }
-void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); }
-void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); }
-void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); }
-void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); }
-void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); }
-void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); }
-void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); }
-void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); }
-void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); }
-void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); }
-void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); }
-void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); }
-void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, reg, 0x0F, 0xC3); }
-void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); }
-void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); }
-void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); }
-void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); }
-void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); }
-void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); }
-void movsb() { db(0xA4); }
-void movsd() { db(0xA5); }
-void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); }
-void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); }
-void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); }
-void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); }
-void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); }
-void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); }
-void movsw() { db(0x66); db(0xA5); }
-void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); }
-void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); }
-void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); }
-void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); }
-void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); }
-void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); }
-void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); }
-void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); }
-void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); }
-void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); }
-void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); }
-void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); }
-void mwait() { db(0x0F); db(0x01); db(0xC9); }
-void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); }
-void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); }
-void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); }
-void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); }
-void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); }
-void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); }
-void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); }
-void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); }
-void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); }
-void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); }
-void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); }
-void packusdw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); }
-void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); }
-void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); }
-void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); }
-void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); }
-void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); }
-void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); }
-void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); }
-void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); }
-void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast<uint8>(imm), 0x3a); }
-void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); }
-void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); }
-void pause() { db(0xF3); db(0x90); }
-void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); }
-void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); }
-void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); }
-void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); }
-void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); }
-void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); }
-void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); }
-void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); }
-void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); }
-void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); }
-void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); }
-void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); }
-void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); }
-void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); }
-void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); }
-void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); }
-void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); }
-void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); }
-void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); }
-void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); }
-void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); }
-void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); }
-void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); }
-void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); }
-void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); }
-void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); }
-void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); }
-void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); }
-void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); }
-void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); }
-void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); }
-void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); }
-void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); }
-void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); }
-void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); }
-void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); }
-void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); }
-void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); }
-void popcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); }
-void popf() { db(0x9D); }
-void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); }
-void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); }
-void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); }
-void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); }
-void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); }
-void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); }
-void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); }
-void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); }
-void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); }
-void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); }
-void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); }
-void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); }
-void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); }
-void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); }
-void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); }
-void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); }
-void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); }
-void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); }
-void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); }
-void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); }
-void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); }
-void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); }
-void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); }
-void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); }
-void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); }
-void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); }
-void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); }
-void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); }
-void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); }
-void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); }
-void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); }
-void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); }
-void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); }
-void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); }
-void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); }
-void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); }
-void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); }
-void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); }
-void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); }
-void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); }
-void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); }
-void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); }
-void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
-void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); }
-void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); }
-void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); }
-void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); }
-void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); }
-void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); }
-void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); }
-void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); }
-void pushf() { db(0x9C); }
-void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); }
-void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); }
-void rcl(const Operand& op, int imm) { opShift(op, imm, 2); }
-void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); }
-void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); }
-void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); }
-void rcr(const Operand& op, int imm) { opShift(op, imm, 3); }
-void rdmsr() { db(0x0F); db(0x32); }
-void rdpmc() { db(0x0F); db(0x33); }
-void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); }
-void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); }
-void rdtsc() { db(0x0F); db(0x31); }
-void rdtscp() { db(0x0F); db(0x01); db(0xF9); }
-void rep() { db(0xF3); }
-void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } }
-void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); }
-void rol(const Operand& op, int imm) { opShift(op, imm, 0); }
-void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); }
-void ror(const Operand& op, int imm) { opShift(op, imm, 1); }
-void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); }
-void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
-void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
-void rsqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); }
-void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); }
-void sahf() { db(0x9E); }
-void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); }
-void sal(const Operand& op, int imm) { opShift(op, imm, 4); }
-void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); }
-void sar(const Operand& op, int imm) { opShift(op, imm, 7); }
-void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); }
-void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); }
-void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); }
-void scasb() { db(0xAE); }
-void scasd() { db(0xAF); }
-void scasw() { db(0x66); db(0xAF); }
-void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524
-void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
-void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
-void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524
-void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
-void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524
-void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524
-void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524
-void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524
-void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524
-void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524
-void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
-void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
-void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524
-void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
-void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524
-void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524
-void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524
-void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524
-void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524
-void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524
-void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524
-void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524
-void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524
-void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524
-void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524
-void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524
-void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524
-void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524
-void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524
-void sfence() { db(0x0F); db(0xAE); db(0xF8); }
-void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); }
-void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); }
-void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); }
-void shl(const Operand& op, int imm) { opShift(op, imm, 4); }
-void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); }
-void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); }
-void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); }
-void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); }
-void shr(const Operand& op, int imm) { opShift(op, imm, 5); }
-void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); }
-void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); }
-void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); }
-void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); }
-void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); }
-void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); }
-void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); }
-void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); }
-void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); }
-void stac() { db(0x0F); db(0x01); db(0xCB); }
-void stc() { db(0xF9); }
-void std() { db(0xFD); }
-void sti() { db(0xFB); }
-void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); }
-void stosb() { db(0xAA); }
-void stosd() { db(0xAB); }
-void stosw() { db(0x66); db(0xAB); }
-void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); }
-void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); }
-void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); }
-void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); }
-void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); }
-void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); }
-void tzcnt(const Reg&reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); }
-void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); }
-void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); }
-void ud2() { db(0x0F); db(0x0B); }
-void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); }
-void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); }
-void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); }
-void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); }
-void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); }
-void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); }
-void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); }
-void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); }
-void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); }
-void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); }
-void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); }
-void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); }
-void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); }
-void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); }
-void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); }
-void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); }
-void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); }
-void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); }
-void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); }
-void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); }
-void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); }
-void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); }
-void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); }
-void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); }
-void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); }
-void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); }
-void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); }
-void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); }
-void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); }
-void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); }
-void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); }
-void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); }
-void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); }
-void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); }
-void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); }
-void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); }
-void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); }
-void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); }
-void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); }
-void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); }
-void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); }
-void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); }
-void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); }
-void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); }
-void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); }
-void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); }
-void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); }
-void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); }
-void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); }
-void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); }
-void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); }
-void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); }
-void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); }
-void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); }
-void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); }
-void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); }
-void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); }
-void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); }
-void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); }
-void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); }
-void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); }
-void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); }
-void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); }
-void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); }
-void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); }
-void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); }
-void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); }
-void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); }
-void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); }
-void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); }
-void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); }
-void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); }
-void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); }
-void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); }
-void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); }
-void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); }
-void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); }
-void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); }
-void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); }
-void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); }
-void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); }
-void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); }
-void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); }
-void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); }
-void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); }
-void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); }
-void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); }
-void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); }
-void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); }
-void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); }
-void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); }
-void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); }
-void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); }
-void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); }
-void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); }
-void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); }
-void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); }
-void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); }
-void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); }
-void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); }
-void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); }
-void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); }
-void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); }
-void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); }
-void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); }
-void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); }
-void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); }
-void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); }
-void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); }
-void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); }
-void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); }
-void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); }
-void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); }
-void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); }
-void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); }
-void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); }
-void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); }
-void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); }
-void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); }
-void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); }
-void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); }
-void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); }
-void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); }
-void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); }
-void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); }
-void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); }
-void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); }
-void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); }
-void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); }
-void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); }
-void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); }
-void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); }
-void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); }
-void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); }
-void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); }
-void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); }
-void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); }
-void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); }
-void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); }
-void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); }
-void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); }
-void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); }
-void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); }
-void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); }
-void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); }
-void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); }
-void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); }
-void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); }
-void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); }
-void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); }
-void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); }
-void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); }
-void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); }
-void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); }
-void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); }
-void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); }
-void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); }
-void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); }
-void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); }
-void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); }
-void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); }
-void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); }
-void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); }
-void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); }
-void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); }
-void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); }
-void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); }
-void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); }
-void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); }
-void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); }
-void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); }
-void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); }
-void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); }
-void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); }
-void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); }
-void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); }
-void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX |T_EW1 | T_B64 | T_ER_Z, 0xE6); }
-void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); }
-void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); }
-void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); }
-void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); }
-void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); }
-void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); }
-void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); }
-void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); }
-void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); }
-void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); }
-void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); }
-void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); }
-void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); }
-void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); }
-void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); }
-void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); }
-void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); }
-void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); }
-void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); }
-void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); }
-void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); }
-void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); }
-void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); }
-void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); }
-void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); }
-void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); }
-void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); }
-void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); }
-void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); }
-void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); }
-void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); }
-void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); }
-void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); }
-void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); }
-void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); }
-void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); }
-void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); }
-void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); }
-void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); }
-void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); }
-void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); }
-void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); }
-void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); }
-void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); }
-void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); }
-void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); }
-void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); }
-void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); }
-void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); }
-void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); }
-void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); }
-void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); }
-void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); }
-void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); }
-void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); }
-void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); }
-void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); }
-void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); }
-void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); }
-void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBD); }
-void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); }
-void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); }
-void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); }
-void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); }
-void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); }
-void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); }
-void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); }
-void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); }
-void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); }
-void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); }
-void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); }
-void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); }
-void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); }
-void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); }
-void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); }
-void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); }
-void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); }
-void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); }
-void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); }
-void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); }
-void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); }
-void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); }
-void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); }
-void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); }
-void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); }
-void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); }
-void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); }
-void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); }
-void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); }
-void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); }
-void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); }
-void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); }
-void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); }
-void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); }
-void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); }
-void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); }
-void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); }
-void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); }
-void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); }
-void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); }
-void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); }
-void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); }
-void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); }
-void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); }
-void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); }
-void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); }
-void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); }
-void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); }
-void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); }
-void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); }
-void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); }
-void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); }
-void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); }
-void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); }
-void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); }
-void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); }
-void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); }
-void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); }
-void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); }
-void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); }
-void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); }
-void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); }
-void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); }
-void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); }
-void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); }
-void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); }
-void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); }
-void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); }
-void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); }
-void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); }
-void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); }
-void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); }
-void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); }
-void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); }
-void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); }
-void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); }
-void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); }
-void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); }
-void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); }
-void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); }
-void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); }
-void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); }
-void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); }
-void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); }
-void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); }
-void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); }
-void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); }
-void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); }
-void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); }
-void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); }
-void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); }
-void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); }
-void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); }
-void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); }
-void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); }
-void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); }
-void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); }
-void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); }
-void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); }
-void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); }
-void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); }
-void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); }
-void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); }
-void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); }
-void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); }
-void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); }
-void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); }
-void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); }
-void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); }
-void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); }
-void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); }
-void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); }
-void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); }
-void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); }
-void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); }
-void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); }
-void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); }
-void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); }
-void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); }
-void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); }
-void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); }
-void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); }
-void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); }
-void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); }
-void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); }
-void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); }
-void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); }
-void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); }
-void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); }
-void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); }
-void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); }
-void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); }
-void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); }
-void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); }
-void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); }
-void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); }
-void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); }
-void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); }
-void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); }
-void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); }
-void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); }
-void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); }
-void vpextrb(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); }
-void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); }
-void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); }
-void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } }
-void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); }
-void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); }
-void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); }
-void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); }
-void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); }
-void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); }
-void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); }
-void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); }
-void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); }
-void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); }
-void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); }
-void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); }
-void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); }
-void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); }
-void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); }
-void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); }
-void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); }
-void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); }
-void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); }
-void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); }
-void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); }
-void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); }
-void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); }
-void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); }
-void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); }
-void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); }
-void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); }
-void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); }
-void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); }
-void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); }
-void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); }
-void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); }
-void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); }
-void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); }
-void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); }
-void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); }
-void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); }
-void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); }
-void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); }
-void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); }
-void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); }
-void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); }
-void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); }
-void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); }
-void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); }
-void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); }
-void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); }
-void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); }
-void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); }
-void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); }
-void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); }
-void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); }
-void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); }
-void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); }
-void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); }
-void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); }
-void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); }
-void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); }
-void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); }
-void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); }
-void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); }
-void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); }
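-// Note on the shift-by-immediate forms below: the first argument to opAVX_X_X_XM is a
-// pseudo register Xmm(x.getKind(), N) whose index N appears to supply the ModRM /N opcode
-// extension (e.g. /6 for vpslld, /2 for vpsrld, /4 for vpsrad, all sharing opcode 0x72).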
-void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
-void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); }
-void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); }
-void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); }
-void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); }
-void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); }
-void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); }
-void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
-void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); }
-void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
-void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); }
-void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); }
-void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
-void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); }
-void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
-void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); }
-void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); }
-void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); }
-void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); }
-void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); }
-void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); }
-void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
-void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); }
-void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); }
-void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); }
-void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); }
-void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); }
-void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); }
-void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); }
-void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); }
-void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); }
-void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); }
-void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); }
-void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); }
-void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); }
-void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); }
-void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); }
-void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); }
-void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); }
-void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); }
-void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); }
-void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); }
-void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); }
-void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); }
-void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); }
-void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); }
-void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); }
-void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); }
-void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x52); }
-void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); }
-void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); }
-void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); }
-void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); }
-void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); }
-void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); }
-void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); }
-void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); }
-void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); }
-void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); }
-void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); }
-void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); }
-void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); }
-void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); }
-void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); }
-void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); }
-void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); }
-void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); }
-void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); }
-void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); }
-void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); }
-void vzeroall() { db(0xC5); db(0xFC); db(0x77); }
-void vzeroupper() { db(0xC5); db(0xF8); db(0x77); }
-void wait() { db(0x9B); }
-void wbinvd() { db(0x0F); db(0x09); }
-void wrmsr() { db(0x0F); db(0x30); }
-void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 0 : 1)); }
-void xgetbv() { db(0x0F); db(0x01); db(0xD0); }
-void xlatb() { db(0xD7); }
-void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); }
-void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); }
-void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); }
-void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); }
-#ifdef XBYAK_ENABLE_OMITTED_OPERAND
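-// Convenience overloads (only when XBYAK_ENABLE_OMITTED_OPERAND is defined): each
-// two-operand form below forwards to the three-operand form above with the destination
-// repeated as the first source, e.g. vpaddd(x, op) emits the same code as vpaddd(x, x, op).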
-void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); }
-void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); }
-void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); }
-void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); }
-void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); }
-void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); }
-void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); }
-void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); }
-void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); }
-void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); }
-void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); }
-void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); }
-void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); }
-void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); }
-void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); }
-void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); }
-void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); }
-void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); }
-void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); }
-void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); }
-void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); }
-void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); }
-void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); }
-void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); }
-void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); }
-void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); }
-void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); }
-void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); }
-void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); }
-void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); }
-void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); }
-void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); }
-void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); }
-void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); }
-void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); }
-void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); }
-void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); }
-void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); }
-void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); }
-void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); }
-void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); }
-void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); }
-void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); }
-void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); }
-void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); }
-void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); }
-void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); }
-void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); }
-void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); }
-void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); }
-void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); }
-void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); }
-void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); }
-void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); }
-void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); }
-void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); }
-void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); }
-void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); }
-void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); }
-void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); }
-void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); }
-void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); }
-void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); }
-void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); }
-void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); }
-void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); }
-void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); }
-void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); }
-void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); }
-void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); }
-void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); }
-void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); }
-void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); }
-void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); }
-void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); }
-void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); }
-void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); }
-void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); }
-void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); }
-void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); }
-void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); }
-void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); }
-void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); }
-void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); }
-void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); }
-void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); }
-void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); }
-void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); }
-void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); }
-void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); }
-void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); }
-void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); }
-void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); }
-void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); }
-void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); }
-void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); }
-void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); }
-void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); }
-void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); }
-void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); }
-void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); }
-void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); }
-void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); }
-void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); }
-void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); }
-void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); }
-void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); }
-void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); }
-void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); }
-void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); }
-void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); }
-void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); }
-void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); }
-void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); }
-void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); }
-void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); }
-void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); }
-void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); }
-void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); }
-void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); }
-void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); }
-void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); }
-void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); }
-void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); }
-void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); }
-void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); }
-void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); }
-void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); }
-void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); }
-void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); }
-void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); }
-void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); }
-void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); }
-void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); }
-void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); }
-void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); }
-void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); }
-void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); }
-void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); }
-void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); }
-void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); }
-void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); }
-void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); }
-void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); }
-void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); }
-void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); }
-void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); }
-void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); }
-void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); }
-void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); }
-void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); }
-void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); }
-void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); }
-void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); }
-void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); }
-void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); }
-void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); }
-void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); }
-void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); }
-void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); }
-void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); }
-void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); }
-void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); }
-void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); }
-void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); }
-void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); }
-void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); }
-void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); }
-void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); }
-void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); }
-void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); }
-void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); }
-void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); }
-void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); }
-void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); }
-void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); }
-void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); }
-void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); }
-void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); }
-void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); }
-void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); }
-void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); }
-void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); }
-void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); }
-void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); }
-void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); }
-void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); }
-void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); }
-void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); }
-void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); }
-void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); }
-void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); }
-void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); }
-void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); }
-void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); }
-void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); }
-void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); }
-void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); }
-void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); }
-void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); }
-void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); }
-void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); }
-void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); }
-void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); }
-void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); }
-void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); }
-void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); }
-void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); }
-void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); }
-void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); }
-void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); }
-void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); }
-void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); }
-void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); }
-void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); }
-void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); }
-void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); }
-void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); }
-void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); }
-void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); }
-void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); }
-void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); }
-void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); }
-void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); }
-void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); }
-void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); }
-void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); }
-void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); }
-void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); }
-void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); }
-void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); }
-void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); }
-void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); }
-void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); }
-void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); }
-void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); }
-void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); }
-void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); }
-void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); }
-void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); }
-void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); }
-void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); }
-void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); }
-void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); }
-void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); }
-void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); }
-void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); }
-void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); }
-void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); }
-void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); }
-void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); }
-void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); }
-void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); }
-void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); }
-void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); }
-void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); }
-#endif
-#ifdef XBYAK64
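-// Mnemonics available only in 64-bit mode (jrcxz, cdqe, cqo, movsxd, movq with Reg64,
-// pextrq/pinsrq, ...); the #else branch below carries the 32-bit-only encodings
-// (jcxz, aaa/aas/daa/das, pusha/popa, pushad/popad, ...).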
-void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void cdqe() { db(0x48); db(0x98); }
-void cqo() { db(0x48); db(0x99); }
-void cmpsq() { db(0x48); db(0xA7); }
-void movsq() { db(0x48); db(0xA5); }
-void scasq() { db(0x48); db(0xAF); }
-void stosq() { db(0x48); db(0xAB); }
-void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); }
-void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); }
-void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); }
-void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); }
-void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); }
-void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); }
-void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); }
-void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); }
-void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); }
-void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); }
-void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); }
-void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); }
-#else
-void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
-void aaa() { db(0x37); }
-void aad() { db(0xD5); db(0x0A); }
-void aam() { db(0xD4); db(0x0A); }
-void aas() { db(0x3F); }
-void daa() { db(0x27); }
-void das() { db(0x2F); }
-void popad() { db(0x61); }
-void popfd() { db(0x9D); }
-void pusha() { db(0x60); }
-void pushad() { db(0x60); }
-void pushfd() { db(0x9C); }
-void popa() { db(0x61); }
-#endif
-#ifndef XBYAK_NO_OP_NAMES
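-// Keyword-style aliases: unless XBYAK_NO_OP_NAMES is defined, and/or/xor/not are
-// provided as thin wrappers that forward to the underscore-suffixed forms
-// (and_, or_, xor_, not_).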
-void and(const Operand& op1, const Operand& op2) { and_(op1, op2); }
-void and(const Operand& op, uint32 imm) { and_(op, imm); }
-void or(const Operand& op1, const Operand& op2) { or_(op1, op2); }
-void or(const Operand& op, uint32 imm) { or_(op, imm); }
-void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); }
-void xor(const Operand& op, uint32 imm) { xor_(op, imm); }
-void not(const Operand& op) { not_(op); }
-#endif
-#ifndef XBYAK_DISABLE_AVX512
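-// AVX-512 section (omitted when XBYAK_DISABLE_AVX512 is defined): opmask (k-register)
-// arithmetic, moves and shifts (kadd*, kand*, kmov*, kshift*, ktest*, kxor*, ...)
-// followed by EVEX-encoded vector instructions (note the T_MUST_EVEX flag).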
-void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); }
-void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); }
-void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); }
-void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); }
-void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); }
-void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); }
-void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); }
-void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); }
-void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); }
-void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); }
-void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); }
-void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); }
-void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); }
-void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); }
-void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); }
-void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); }
-void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); }
-void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); }
-void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); }
-void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); }
-void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); }
-void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); }
-void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); }
-void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); }
-void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); }
-void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); }
-void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); }
-void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); }
-void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); }
-void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); }
-void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); }
-void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); }
-void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); }
-void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); }
-void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); }
-void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); }
-void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); }
-void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); }
-void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); }
-void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); }
-void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); }
-void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); }
-void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); }
-void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); }
-void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); }
-void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); }
-void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); }
-void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); }
-void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); }
-void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); }
-void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); }
-void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); }
-void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); }
-void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); }
-void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); }
-void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); }
-void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); }
-void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); }
-void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); }
-void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); }
-void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); }
-void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); }
-void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); }
-void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); }
-void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); }
-void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); }
-void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); }
-void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); }
-void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); }
-void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); }
-void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); }
-void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); }
-void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); }
-void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); }
-void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); }
-void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); }
-void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); }
-void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); }
-void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); }
-void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
-void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
-void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
-void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
-void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); }
-void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); }
-void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); }
-void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); }
-void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); }
-void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); }
-void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); }
-void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); }
-void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); }
-void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); }
-void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); }
-void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); }
-void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); }
-void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); }
-void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); }
-void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); }
-void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); }
-void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); }
-void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); }
-void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); }
-void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); }
-void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); }
-void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); }
-void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); }
-void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); }
-void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); }
-void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); }
-void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); }
-void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); }
-void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); }
-void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); }
-void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); }
-void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); }
-void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); }
-void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); }
-void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); }
-void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); }
-void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); }
-void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); }
-void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); }
-void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); }
-void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); }
-void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); }
-void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); }
-void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); }
-void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); }
-void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); }
-void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); }
-void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); }
-void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); }
-void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); }
-void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
-void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
-void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
-void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
-void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); }
-void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); }
-void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); }
-void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); }
-void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); }
-void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); }
-void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); }
-void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); }
-void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); }
-void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); }
-void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); }
-void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); }
-void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); }
-void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); }
-void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); }
-void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); }
-void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); }
-void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); }
-void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
-void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
-void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); }
-void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); }
-void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); }
-void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); }
-void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); }
-void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); }
-void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); }
-void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); }
-void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); }
-void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); }
-void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); }
-void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); }
-void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); }
-void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); }
-void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); }
-void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); }
-void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); }
-void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); }
-void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); }
-void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); }
-void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); }
-void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); }
-void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); }
-void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); }
-void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); }
-void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); }
-void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); }
-void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); }
-void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); }
-void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1E, imm); }
-void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); }
-void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); }
-void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); }
-void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); }
-void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); }
-void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); }
-void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); }
-void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); }
-void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); }
-void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); }
-void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); }
-void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); }
-void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); }
-void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); }
-void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); }
-void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); }
-void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); }
-void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); }
-void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); }
-void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); }
-void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); }
-void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); }
-void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); }
-void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); }
-void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); }
-void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); }
-void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); }
-void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); }
-void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); }
-void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); }
-void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); }
-void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); }
-void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); }
-void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); }
-void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); }
-void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); }
-void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); }
-void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); }
-void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); }
-void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); }
-void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); }
-void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); }
-void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); }
-void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); }
-void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); }
-void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); }
-void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); }
-void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); }
-void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); }
-void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); }
-void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); }
-void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); }
-void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); }
-void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); }
-void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); }
-void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); }
-void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); }
-void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); }
-void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); }
-void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); }
-void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); }
-void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); }
-void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); }
-void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); }
-void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); }
-void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); }
-void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); }
-void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); }
-void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); }
-void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); }
-void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); }
-void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); }
-void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); }
-void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); }
-void vprold(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); }
-void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
-void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); }
-void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); }
-void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); }
-void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
-void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); }
-void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); }
-void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); }
-void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); }
-void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); }
-void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); }
-void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); }
-void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); }
-void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); }
-void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); }
-void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); }
-void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); }
-void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); }
-void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); }
-void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); }
-void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); }
-void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); }
-void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); }
-void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); }
-void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); }
-void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
-void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); }
-void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); }
-void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); }
-void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); }
-void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); }
-void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); }
-void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); }
-void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); }
-void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); }
-void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); }
-void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); }
-void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); }
-void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); }
-void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); }
-void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); }
-void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); }
-void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); }
-void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); }
-void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); }
-void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); }
-void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); }
-void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); }
-void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); }
-void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); }
-void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); }
-void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); }
-void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); }
-void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); }
-void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); }
-void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); }
-void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); }
-void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); }
-void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); }
-void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); }
-void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); }
-void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); }
-void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); }
-void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); }
-void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); }
-void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); }
-void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); }
-void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); }
-void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); }
-void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); }
-void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); }
-void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); }
-void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); }
-void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); }
-void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); }
-void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); }
-void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
-void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
-void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
-void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
-void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
-void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); }
-void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); }
-void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); }
-void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); }
-void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); }
-void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); }
-#ifdef XBYAK64
-void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); }
-void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); }
-void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); }
-#endif
-#endif
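For reference, a minimal sketch of how the EVEX-encoded generator methods removed above are driven in practice. It assumes only Xbyak's public CodeGenerator API; the include path and the ZeroZmm0 class name are illustrative, while vpxord/ret and the zmm0 register member come from the deleted header itself.

#include "xbyak.h"

// JIT a tiny function that zeroes zmm0 using the EVEX form deleted above.
struct ZeroZmm0 : Xbyak::CodeGenerator {
	ZeroZmm0() {
		vpxord(zmm0, zmm0, zmm0); // vpxord(const Xmm&, const Xmm&, const Operand&)
		ret();
	}
};
// ZeroZmm0 gen; gen.getCode() then points at the emitted machine code.
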
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h
deleted file mode 100644
index 8ef076e680..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h
+++ /dev/null
@@ -1,772 +0,0 @@
-/*******************************************************************************
-* Copyright 2016-2019 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*******************************************************************************
-* Copyright (c) 2007 MITSUNARI Shigeo
-* All rights reserved.
-*
-* Redistribution and use in source and binary forms, with or without
-* modification, are permitted provided that the following conditions are met:
-*
-* Redistributions of source code must retain the above copyright notice, this
-* list of conditions and the following disclaimer.
-* Redistributions in binary form must reproduce the above copyright notice,
-* this list of conditions and the following disclaimer in the documentation
-* and/or other materials provided with the distribution.
-* Neither the name of the copyright owner nor the names of its contributors may
-* be used to endorse or promote products derived from this software without
-* specific prior written permission.
-*
-* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
-* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
-* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
-* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
-* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
-* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
-* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
-* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
-* THE POSSIBILITY OF SUCH DAMAGE.
-*******************************************************************************/
-
-#ifndef XBYAK_XBYAK_UTIL_H_
-#define XBYAK_XBYAK_UTIL_H_
-
-/**
- utility class and functions for Xbyak
- Xbyak::util::Clock ; rdtsc timer
- Xbyak::util::Cpu ; detect CPU
- @note this header is UNDER CONSTRUCTION!
-*/
-#include "xbyak.h"
-
-#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
- #define XBYAK_INTEL_CPU_SPECIFIC
-#endif
-
-#ifdef XBYAK_INTEL_CPU_SPECIFIC
-#ifdef _MSC_VER
- #if (_MSC_VER < 1400) && defined(XBYAK32)
- static inline __declspec(naked) void __cpuid(int[4], int)
- {
- __asm {
- push ebx
- push esi
- mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn
- cpuid
- mov esi, dword ptr [esp + 4 * 2 + 4] // data
- mov dword ptr [esi], eax
- mov dword ptr [esi + 4], ebx
- mov dword ptr [esi + 8], ecx
- mov dword ptr [esi + 12], edx
- pop esi
- pop ebx
- ret
- }
- }
- #else
- #include <intrin.h> // for __cpuid
- #endif
-#else
- #ifndef __GNUC_PREREQ
- #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor)))
- #endif
- #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__)
- #include <cpuid.h>
- #else
- #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm'
- #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn))
- #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebp, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn))
- #else
- #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn))
- #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn))
- #endif
- #endif
-#endif
-#endif
-
-namespace Xbyak { namespace util {
-
-typedef enum {
- SmtLevel = 1,
- CoreLevel = 2
-} IntelCpuTopologyLevel;
-
-/**
- CPU detection class
-*/
-class Cpu {
- uint64 type_;
- //system topology
- bool x2APIC_supported_;
- static const size_t maxTopologyLevels = 2;
- unsigned int numCores_[maxTopologyLevels];
-
- static const unsigned int maxNumberCacheLevels = 10;
- unsigned int dataCacheSize_[maxNumberCacheLevels];
- unsigned int coresSharignDataCache_[maxNumberCacheLevels];
- unsigned int dataCacheLevels_;
-
- unsigned int get32bitAsBE(const char *x) const
- {
- return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24);
- }
- unsigned int mask(int n) const
- {
- return (1U << n) - 1;
- }
- void setFamily()
- {
- unsigned int data[4] = {};
- getCpuid(1, data);
- stepping = data[0] & mask(4);
- model = (data[0] >> 4) & mask(4);
- family = (data[0] >> 8) & mask(4);
- // type = (data[0] >> 12) & mask(2);
- extModel = (data[0] >> 16) & mask(4);
- extFamily = (data[0] >> 20) & mask(8);
- if (family == 0x0f) {
- displayFamily = family + extFamily;
- } else {
- displayFamily = family;
- }
- if (family == 6 || family == 0x0f) {
- displayModel = (extModel << 4) + model;
- } else {
- displayModel = model;
- }
- }
- unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end)
- {
- return (val >> base) & ((1u << (end - base)) - 1);
- }
- void setNumCores()
- {
- if ((type_ & tINTEL) == 0) return;
-
- unsigned int data[4] = {};
-
-		/* CAUTION: These numbers reflect the configuration as shipped by Intel. */
- getCpuidEx(0x0, 0, data);
- if (data[0] >= 0xB) {
- /*
-			If leaf 0xB exists (x2APIC is supported),
-			we use it to get the number of SMT cores and cores on the socket.
-
-			Leaf 0xB can be zeroed out by a hypervisor.
- */
- x2APIC_supported_ = true;
- for (unsigned int i = 0; i < maxTopologyLevels; i++) {
- getCpuidEx(0xB, i, data);
- IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15);
- if (level == SmtLevel || level == CoreLevel) {
- numCores_[level - 1] = extractBit(data[1], 0, 15);
- }
- }
- } else {
- /*
-			Failed to determine the number of cores without x2APIC support.
-			TODO: use the initial APIC ID to determine the number of cores.
- */
- numCores_[SmtLevel - 1] = 0;
- numCores_[CoreLevel - 1] = 0;
- }
-
- }
- void setCacheHierarchy()
- {
- if ((type_ & tINTEL) == 0) return;
- const unsigned int NO_CACHE = 0;
- const unsigned int DATA_CACHE = 1;
-// const unsigned int INSTRUCTION_CACHE = 2;
- const unsigned int UNIFIED_CACHE = 3;
- unsigned int smt_width = 0;
- unsigned int logical_cores = 0;
- unsigned int data[4] = {};
-
- if (x2APIC_supported_) {
- smt_width = numCores_[0];
- logical_cores = numCores_[1];
- }
-
- /*
-		Assumptions:
-		The first level of data cache is not shared (which is the
-		case for every existing architecture), and we use this to
-		determine the SMT width for architectures not supporting leaf 0xB.
-		When leaf 4 reports fewer cores than the per-socket count
-		reported by leaf 0xB, that value is the correct number of
-		cores and not an upper bound.
- */
- for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) {
- getCpuidEx(0x4, i, data);
- unsigned int cacheType = extractBit(data[0], 0, 4);
- if (cacheType == NO_CACHE) break;
- if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) {
- unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1;
- if (logical_cores != 0) { // true only if leaf 0xB is supported and valid
- actual_logical_cores = (std::min)(actual_logical_cores, logical_cores);
- }
- assert(actual_logical_cores != 0);
- dataCacheSize_[dataCacheLevels_] =
- (extractBit(data[1], 22, 31) + 1)
- * (extractBit(data[1], 12, 21) + 1)
- * (extractBit(data[1], 0, 11) + 1)
- * (data[2] + 1);
- if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores;
- assert(smt_width != 0);
- // FIXME: check and fix number of cores sharing L3 cache for different configurations
- // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket)
- coresSharignDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u);
- dataCacheLevels_++;
- }
- }
- }
-
-public:
- int model;
- int family;
- int stepping;
- int extModel;
- int extFamily;
- int displayFamily; // family + extFamily
- int displayModel; // model + extModel
-
- unsigned int getNumCores(IntelCpuTopologyLevel level) {
- if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER);
- if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED);
- return (level == CoreLevel)
- ? numCores_[level - 1] / numCores_[SmtLevel - 1]
- : numCores_[level - 1];
- }
-
- unsigned int getDataCacheLevels() const { return dataCacheLevels_; }
- unsigned int getCoresSharingDataCache(unsigned int i) const
- {
- if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER);
- return coresSharignDataCache_[i];
- }
- unsigned int getDataCacheSize(unsigned int i) const
- {
- if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER);
- return dataCacheSize_[i];
- }
-
- /*
- data[] = { eax, ebx, ecx, edx }
- */
- static inline void getCpuid(unsigned int eaxIn, unsigned int data[4])
- {
-#ifdef XBYAK_INTEL_CPU_SPECIFIC
- #ifdef _MSC_VER
- __cpuid(reinterpret_cast<int*>(data), eaxIn);
- #else
- __cpuid(eaxIn, data[0], data[1], data[2], data[3]);
- #endif
-#else
- (void)eaxIn;
- (void)data;
-#endif
- }
- static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4])
- {
-#ifdef XBYAK_INTEL_CPU_SPECIFIC
- #ifdef _MSC_VER
- __cpuidex(reinterpret_cast<int*>(data), eaxIn, ecxIn);
- #else
- __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]);
- #endif
-#else
- (void)eaxIn;
- (void)ecxIn;
- (void)data;
-#endif
- }
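	// Usage sketch (illustrative; not from the removed file): both helpers fill
	// data[] with { eax, ebx, ecx, edx } for the requested leaf, e.g.
	//   unsigned int regs[4] = {};
	//   Cpu::getCpuidEx(7, 0, regs);          // structured extended feature flags
	//   bool avx512f = (regs[1] >> 16) & 1;   // EBX bit 16 = AVX-512F (see ctor below)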
- static inline uint64 getXfeature()
- {
-#ifdef XBYAK_INTEL_CPU_SPECIFIC
- #ifdef _MSC_VER
- return _xgetbv(0);
- #else
- unsigned int eax, edx;
-		// xgetbv is not supported on gcc 4.2
-// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
- __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0));
- return ((uint64)edx << 32) | eax;
- #endif
-#else
- return 0;
-#endif
- }
- typedef uint64 Type;
-
- static const Type NONE = 0;
- static const Type tMMX = 1 << 0;
- static const Type tMMX2 = 1 << 1;
- static const Type tCMOV = 1 << 2;
- static const Type tSSE = 1 << 3;
- static const Type tSSE2 = 1 << 4;
- static const Type tSSE3 = 1 << 5;
- static const Type tSSSE3 = 1 << 6;
- static const Type tSSE41 = 1 << 7;
- static const Type tSSE42 = 1 << 8;
- static const Type tPOPCNT = 1 << 9;
- static const Type tAESNI = 1 << 10;
- static const Type tSSE5 = 1 << 11;
- static const Type tOSXSAVE = 1 << 12;
- static const Type tPCLMULQDQ = 1 << 13;
- static const Type tAVX = 1 << 14;
- static const Type tFMA = 1 << 15;
-
- static const Type t3DN = 1 << 16;
- static const Type tE3DN = 1 << 17;
- static const Type tSSE4a = 1 << 18;
- static const Type tRDTSCP = 1 << 19;
- static const Type tAVX2 = 1 << 20;
- static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt
- static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx
- static const Type tLZCNT = 1 << 23;
-
- static const Type tINTEL = 1 << 24;
- static const Type tAMD = 1 << 25;
-
- static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb
- static const Type tRDRAND = 1 << 27;
- static const Type tADX = 1 << 28; // adcx, adox
- static const Type tRDSEED = 1 << 29; // rdseed
- static const Type tSMAP = 1 << 30; // stac
- static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest
- static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort
- static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph
-	static const Type tMOVBE = uint64(1) << 34; // movbe
- static const Type tAVX512F = uint64(1) << 35;
- static const Type tAVX512DQ = uint64(1) << 36;
- static const Type tAVX512_IFMA = uint64(1) << 37;
- static const Type tAVX512IFMA = tAVX512_IFMA;
- static const Type tAVX512PF = uint64(1) << 38;
- static const Type tAVX512ER = uint64(1) << 39;
- static const Type tAVX512CD = uint64(1) << 40;
- static const Type tAVX512BW = uint64(1) << 41;
- static const Type tAVX512VL = uint64(1) << 42;
- static const Type tAVX512_VBMI = uint64(1) << 43;
-	static const Type tAVX512VBMI = tAVX512_VBMI; // name changed in Intel's manual
- static const Type tAVX512_4VNNIW = uint64(1) << 44;
- static const Type tAVX512_4FMAPS = uint64(1) << 45;
- static const Type tPREFETCHWT1 = uint64(1) << 46;
- static const Type tPREFETCHW = uint64(1) << 47;
- static const Type tSHA = uint64(1) << 48;
- static const Type tMPX = uint64(1) << 49;
- static const Type tAVX512_VBMI2 = uint64(1) << 50;
- static const Type tGFNI = uint64(1) << 51;
- static const Type tVAES = uint64(1) << 52;
- static const Type tVPCLMULQDQ = uint64(1) << 53;
- static const Type tAVX512_VNNI = uint64(1) << 54;
- static const Type tAVX512_BITALG = uint64(1) << 55;
- static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56;
-
- Cpu()
- : type_(NONE)
- , x2APIC_supported_(false)
- , numCores_()
- , dataCacheSize_()
- , coresSharignDataCache_()
- , dataCacheLevels_(0)
- {
- unsigned int data[4] = {};
- const unsigned int& EAX = data[0];
- const unsigned int& EBX = data[1];
- const unsigned int& ECX = data[2];
- const unsigned int& EDX = data[3];
- getCpuid(0, data);
- const unsigned int maxNum = EAX;
- static const char intel[] = "ntel";
- static const char amd[] = "cAMD";
- if (ECX == get32bitAsBE(amd)) {
- type_ |= tAMD;
- getCpuid(0x80000001, data);
- if (EDX & (1U << 31)) type_ |= t3DN;
- if (EDX & (1U << 15)) type_ |= tCMOV;
- if (EDX & (1U << 30)) type_ |= tE3DN;
- if (EDX & (1U << 22)) type_ |= tMMX2;
- if (EDX & (1U << 27)) type_ |= tRDTSCP;
- }
- if (ECX == get32bitAsBE(intel)) {
- type_ |= tINTEL;
- getCpuid(0x80000001, data);
- if (EDX & (1U << 27)) type_ |= tRDTSCP;
- if (ECX & (1U << 5)) type_ |= tLZCNT;
- if (ECX & (1U << 8)) type_ |= tPREFETCHW;
- }
- getCpuid(1, data);
- if (ECX & (1U << 0)) type_ |= tSSE3;
- if (ECX & (1U << 9)) type_ |= tSSSE3;
- if (ECX & (1U << 19)) type_ |= tSSE41;
- if (ECX & (1U << 20)) type_ |= tSSE42;
- if (ECX & (1U << 22)) type_ |= tMOVBE;
- if (ECX & (1U << 23)) type_ |= tPOPCNT;
- if (ECX & (1U << 25)) type_ |= tAESNI;
- if (ECX & (1U << 1)) type_ |= tPCLMULQDQ;
- if (ECX & (1U << 27)) type_ |= tOSXSAVE;
- if (ECX & (1U << 30)) type_ |= tRDRAND;
- if (ECX & (1U << 29)) type_ |= tF16C;
-
- if (EDX & (1U << 15)) type_ |= tCMOV;
- if (EDX & (1U << 23)) type_ |= tMMX;
- if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE;
- if (EDX & (1U << 26)) type_ |= tSSE2;
-
- if (type_ & tOSXSAVE) {
- // check XFEATURE_ENABLED_MASK[2:1] = '11b'
- uint64 bv = getXfeature();
- if ((bv & 6) == 6) {
- if (ECX & (1U << 28)) type_ |= tAVX;
- if (ECX & (1U << 12)) type_ |= tFMA;
- if (((bv >> 5) & 7) == 7) {
- getCpuidEx(7, 0, data);
- if (EBX & (1U << 16)) type_ |= tAVX512F;
- if (type_ & tAVX512F) {
- if (EBX & (1U << 17)) type_ |= tAVX512DQ;
- if (EBX & (1U << 21)) type_ |= tAVX512_IFMA;
- if (EBX & (1U << 26)) type_ |= tAVX512PF;
- if (EBX & (1U << 27)) type_ |= tAVX512ER;
- if (EBX & (1U << 28)) type_ |= tAVX512CD;
- if (EBX & (1U << 30)) type_ |= tAVX512BW;
- if (EBX & (1U << 31)) type_ |= tAVX512VL;
- if (ECX & (1U << 1)) type_ |= tAVX512_VBMI;
- if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2;
- if (ECX & (1U << 8)) type_ |= tGFNI;
- if (ECX & (1U << 9)) type_ |= tVAES;
- if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ;
- if (ECX & (1U << 11)) type_ |= tAVX512_VNNI;
- if (ECX & (1U << 12)) type_ |= tAVX512_BITALG;
- if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ;
- if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW;
- if (EDX & (1U << 3)) type_ |= tAVX512_4FMAPS;
- }
- }
- }
- }
- if (maxNum >= 7) {
- getCpuidEx(7, 0, data);
- if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2;
- if (EBX & (1U << 3)) type_ |= tBMI1;
- if (EBX & (1U << 8)) type_ |= tBMI2;
- if (EBX & (1U << 9)) type_ |= tENHANCED_REP;
- if (EBX & (1U << 18)) type_ |= tRDSEED;
- if (EBX & (1U << 19)) type_ |= tADX;
- if (EBX & (1U << 20)) type_ |= tSMAP;
- if (EBX & (1U << 4)) type_ |= tHLE;
- if (EBX & (1U << 11)) type_ |= tRTM;
- if (EBX & (1U << 14)) type_ |= tMPX;
- if (EBX & (1U << 29)) type_ |= tSHA;
- if (ECX & (1U << 0)) type_ |= tPREFETCHWT1;
- }
- setFamily();
- setNumCores();
- setCacheHierarchy();
- }
- void putFamily() const
- {
- printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n",
- family, model, stepping, extFamily, extModel);
- printf("display:family=%X, model=%X\n", displayFamily, displayModel);
- }
- bool has(Type type) const
- {
- return (type & type_) != 0;
- }
-};
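For context, the feature bits above are queried through Cpu::has(); a minimal sketch of how callers use the detector (the include path and the printf output are illustrative, not taken from this tree):

    #include <cstdio>
    #include "xbyak/xbyak_util.h"   // assumed vendored header path

    int main()
    {
        Xbyak::util::Cpu cpu;                        // CPUID detection runs in the constructor
        if (cpu.has(Xbyak::util::Cpu::tAVX512F))     // e.g. choose 16-lane kernels
            std::printf("AVX-512F supported\n");
        else if (cpu.has(Xbyak::util::Cpu::tAVX2))   // otherwise fall back to 8-lane kernels
            std::printf("AVX2 supported\n");
        cpu.putFamily();                             // prints family/model/stepping
        return 0;
    }

mkl-dnn's mayiuse() checks (visible in the autoencoder.cpp hunk further down) ultimately consult this same detector.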
-
-class Clock {
-public:
- static inline uint64 getRdtsc()
- {
-#ifdef XBYAK_INTEL_CPU_SPECIFIC
- #ifdef _MSC_VER
- return __rdtsc();
- #else
- unsigned int eax, edx;
- __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx));
- return ((uint64)edx << 32) | eax;
- #endif
-#else
- // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu
- return 0;
-#endif
- }
- Clock()
- : clock_(0)
- , count_(0)
- {
- }
- void begin()
- {
- clock_ -= getRdtsc();
- }
- void end()
- {
- clock_ += getRdtsc();
- count_++;
- }
- int getCount() const { return count_; }
- uint64 getClock() const { return clock_; }
- void clear() { count_ = 0; clock_ = 0; }
-private:
- uint64 clock_;
- int count_;
-};
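A short usage sketch for the Clock helper above (doWork() is a placeholder for the code being measured; rdtsc counts reference cycles, so treat the result as a rough estimate):

    Xbyak::util::Clock clk;
    for (int i = 0; i < 100; i++) {
        clk.begin();
        doWork();                       // placeholder workload
        clk.end();
    }
    if (clk.getCount() > 0)             // avoid dividing by zero
        printf("%.1f cycles/call\n", double(clk.getClock()) / clk.getCount());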
-
-#ifdef XBYAK64
-const int UseRCX = 1 << 6;
-const int UseRDX = 1 << 7;
-
-class Pack {
- static const size_t maxTblNum = 15;
- const Xbyak::Reg64 *tbl_[maxTblNum];
- size_t n_;
-public:
- Pack() : tbl_(), n_(0) {}
- Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); }
- Pack(const Pack& rhs)
- : n_(rhs.n_)
- {
- for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i];
- }
- Pack& operator=(const Pack& rhs)
- {
- n_ = rhs.n_;
- for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i];
- return *this;
- }
- Pack(const Xbyak::Reg64& t0)
- { n_ = 1; tbl_[0] = &t0; }
- Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; }
- Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; }
- Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; }
- Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; }
- Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; }
- Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; }
- Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; }
- Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; }
- Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
- { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; }
- Pack& append(const Xbyak::Reg64& t)
- {
- if (n_ == maxTblNum) {
- fprintf(stderr, "ERR Pack::can't append\n");
- throw Error(ERR_BAD_PARAMETER);
- }
- tbl_[n_++] = &t;
- return *this;
- }
- void init(const Xbyak::Reg64 *tbl, size_t n)
- {
- if (n > maxTblNum) {
- fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n);
- throw Error(ERR_BAD_PARAMETER);
- }
- n_ = n;
- for (size_t i = 0; i < n; i++) {
- tbl_[i] = &tbl[i];
- }
- }
- const Xbyak::Reg64& operator[](size_t n) const
- {
- if (n >= n_) {
- fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_);
- throw Error(ERR_BAD_PARAMETER);
- }
- return *tbl_[n];
- }
- size_t size() const { return n_; }
- /*
- get tbl[pos, pos + num)
- */
- Pack sub(size_t pos, size_t num = size_t(-1)) const
- {
- if (num == size_t(-1)) num = n_ - pos;
- if (pos + num > n_) {
- fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num);
- throw Error(ERR_BAD_PARAMETER);
- }
- Pack pack;
- pack.n_ = num;
- for (size_t i = 0; i < num; i++) {
- pack.tbl_[i] = tbl_[pos + i];
- }
- return pack;
- }
- void put() const
- {
- for (size_t i = 0; i < n_; i++) {
- printf("%s ", tbl_[i]->toString());
- }
- printf("\n");
- }
-};
-
-class StackFrame {
-#ifdef XBYAK64_WIN
- static const int noSaveNum = 6;
- static const int rcxPos = 0;
- static const int rdxPos = 1;
-#else
- static const int noSaveNum = 8;
- static const int rcxPos = 3;
- static const int rdxPos = 2;
-#endif
- static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax
- Xbyak::CodeGenerator *code_;
- int pNum_;
- int tNum_;
- bool useRcx_;
- bool useRdx_;
- int saveNum_;
- int P_;
- bool makeEpilog_;
- Xbyak::Reg64 pTbl_[4];
- Xbyak::Reg64 tTbl_[maxRegNum];
- Pack p_;
- Pack t_;
- StackFrame(const StackFrame&);
- void operator=(const StackFrame&);
-public:
- const Pack& p;
- const Pack& t;
- /*
- make stack frame
- @param sf [in] this
- @param pNum [in] number of function parameters (0 <= pNum <= 4)
- @param tNum [in] number of temporary registers (0 <= tNum, may be or'ed with UseRCX, UseRDX); #{pNum + tNum [+ rcx] [+ rdx]} <= 14
- @param stackSizeByte [in] local stack size
- @param makeEpilog [in] automatically call close() if true
-
- you can use
- rax
- gp0, ..., gp(pNum - 1)
- gt0, ..., gt(tNum-1)
- rcx if tNum & UseRCX
- rdx if tNum & UseRDX
- rsp[0..stackSizeByte - 1]
- */
- StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true)
- : code_(code)
- , pNum_(pNum)
- , tNum_(tNum & ~(UseRCX | UseRDX))
- , useRcx_((tNum & UseRCX) != 0)
- , useRdx_((tNum & UseRDX) != 0)
- , saveNum_(0)
- , P_(0)
- , makeEpilog_(makeEpilog)
- , p(p_)
- , t(t_)
- {
- using namespace Xbyak;
- if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM);
- const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0);
- if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM);
- const Reg64& _rsp = code->rsp;
- saveNum_ = (std::max)(0, allRegNum - noSaveNum);
- const int *tbl = getOrderTbl() + noSaveNum;
- for (int i = 0; i < saveNum_; i++) {
- code->push(Reg64(tbl[i]));
- }
- P_ = (stackSizeByte + 7) / 8;
- if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // if (rsp % 16) == 8 here, increment P_ to keep 16-byte alignment
- P_ *= 8;
- if (P_ > 0) code->sub(_rsp, P_);
- int pos = 0;
- for (int i = 0; i < pNum; i++) {
- pTbl_[i] = Xbyak::Reg64(getRegIdx(pos));
- }
- for (int i = 0; i < tNum_; i++) {
- tTbl_[i] = Xbyak::Reg64(getRegIdx(pos));
- }
- if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx);
- if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx);
- p_.init(pTbl_, pNum);
- t_.init(tTbl_, tNum_);
- }
- /*
- make epilog manually
- @param callRet [in] call ret() if true
- */
- void close(bool callRet = true)
- {
- using namespace Xbyak;
- const Reg64& _rsp = code_->rsp;
- const int *tbl = getOrderTbl() + noSaveNum;
- if (P_ > 0) code_->add(_rsp, P_);
- for (int i = 0; i < saveNum_; i++) {
- code_->pop(Reg64(tbl[saveNum_ - 1 - i]));
- }
-
- if (callRet) code_->ret();
- }
- ~StackFrame()
- {
- if (!makeEpilog_) return;
- try {
- close();
- } catch (std::exception& e) {
- printf("ERR:StackFrame %s\n", e.what());
- //exit(1);
- }
- }
-private:
- const int *getOrderTbl() const
- {
- using namespace Xbyak;
- static const int tbl[] = {
-#ifdef XBYAK64_WIN
- Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI,
-#else
- Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11,
-#endif
- Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15
- };
- return &tbl[0];
- }
- int getRegIdx(int& pos) const
- {
- assert(pos < maxRegNum);
- using namespace Xbyak;
- const int *tbl = getOrderTbl();
- int r = tbl[pos++];
- if (useRcx_) {
- if (r == Operand::RCX) { return Operand::R10; }
- if (r == Operand::R10) { r = tbl[pos++]; }
- }
- if (useRdx_) {
- if (r == Operand::RDX) { return Operand::R11; }
- if (r == Operand::R11) { return tbl[pos++]; }
- }
- return r;
- }
-};
-#endif
-
-} } // end of util
-#endif
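Tying the StackFrame and Pack helpers above together, a sketch of a tiny generated function (the AddFunc name and the int signature are illustrative; sf.p holds the parameter registers, sf.t the scratch registers, and the epilogue plus ret are emitted when sf goes out of scope):

    struct AddFunc : Xbyak::CodeGenerator {
        AddFunc()
        {
            Xbyak::util::StackFrame sf(this, /*pNum=*/2, /*tNum=*/1); // prologue emitted here
            mov(sf.t[0], sf.p[0]);   // t0 = first argument
            add(sf.t[0], sf.p[1]);   // t0 += second argument
            mov(rax, sf.t[0]);       // return value
        }                            // ~StackFrame() emits the epilogue and ret()
    };
    // AddFunc f;
    // int (*fp)(int, int) = (int (*)(int, int))f.getCode();
    // int r = fp(1, 2);   // r == 3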
diff --git a/thirdparty/oidn/patches/godot-changes-c58c5216.patch b/thirdparty/oidn/patches/godot-changes-c58c5216.patch
deleted file mode 100644
index c01f00187b..0000000000
--- a/thirdparty/oidn/patches/godot-changes-c58c5216.patch
+++ /dev/null
@@ -1,337 +0,0 @@
-diff --git a/common/platform.h b/common/platform.h
-index be14bc7..9373b61 100644
---- a/common/platform.h
-+++ b/common/platform.h
-@@ -19,7 +19,7 @@
- #if defined(_WIN32)
- #define WIN32_LEAN_AND_MEAN
- #define NOMINMAX
-- #include <Windows.h>
-+ #include <windows.h>
- #elif defined(__APPLE__)
- #include <sys/sysctl.h>
- #endif
-@@ -129,4 +129,3 @@ namespace oidn {
- std::string getBuildName();
-
- } // namespace oidn
--
-diff --git a/core/autoencoder.cpp b/core/autoencoder.cpp
-index d6915e6..d8da684 100644
---- a/core/autoencoder.cpp
-+++ b/core/autoencoder.cpp
-@@ -90,13 +90,19 @@ namespace oidn {
- if (!dirty)
- return;
-
-- device->executeTask([&]()
-- {
-+ // -- GODOT start --
-+ //device->executeTask([&]()
-+ //{
-+ // GODOT end --
-+
- if (mayiuse(avx512_common))
- net = buildNet<16>();
- else
- net = buildNet<8>();
-- });
-+
-+ // GODOT start --
-+ //});
-+ // GODOT end --
-
- dirty = false;
- }
-@@ -108,9 +114,10 @@ namespace oidn {
-
- if (!net)
- return;
--
-- device->executeTask([&]()
-- {
-+ // -- GODOT start --
-+ //device->executeTask([&]()
-+ //{
-+ // -- GODOT end --
- Progress progress;
- progress.func = progressFunc;
- progress.userPtr = progressUserPtr;
-@@ -156,7 +163,9 @@ namespace oidn {
- tileIndex++;
- }
- }
-- });
-+ // -- GODOT start --
-+ //});
-+ // -- GODOT end --
- }
-
- void AutoencoderFilter::computeTileSize()
-@@ -464,6 +473,11 @@ namespace oidn {
- return std::make_shared<GammaTransferFunction>();
- }
-
-+// -- GODOT start --
-+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-+#if 0
-+// -- GODOT end --
-+
- // --------------------------------------------------------------------------
- // RTFilter
- // --------------------------------------------------------------------------
-@@ -491,6 +505,9 @@ namespace oidn {
- weightData.hdr_alb = weights::rt_hdr_alb;
- weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
- }
-+// -- GODOT start --
-+#endif
-+// -- GODOT end --
-
- // --------------------------------------------------------------------------
- // RTLightmapFilter
-diff --git a/core/autoencoder.h b/core/autoencoder.h
-index c199052..98b6108 100644
---- a/core/autoencoder.h
-+++ b/core/autoencoder.h
-@@ -93,11 +93,18 @@ namespace oidn {
- // RTFilter - Generic ray tracing denoiser
- // --------------------------------------------------------------------------
-
-+// -- GODOT start --
-+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-+#if 0
-+// -- GODOT end --
- class RTFilter : public AutoencoderFilter
- {
- public:
- explicit RTFilter(const Ref<Device>& device);
- };
-+// -- GODOT start --
-+#endif
-+// -- GODOT end --
-
- // --------------------------------------------------------------------------
- // RTLightmapFilter - Ray traced lightmap denoiser
-diff --git a/core/common.h b/core/common.h
-index a3a7e8a..a35dd90 100644
---- a/core/common.h
-+++ b/core/common.h
-@@ -27,7 +27,9 @@
- #include "common/ref.h"
- #include "common/exception.h"
- #include "common/thread.h"
--#include "common/tasking.h"
-+// -- GODOT start --
-+//#include "common/tasking.h"
-+// -- GODOT end --
- #include "math.h"
-
- namespace oidn {
-diff --git a/core/device.cpp b/core/device.cpp
-index c455695..3cd658b 100644
---- a/core/device.cpp
-+++ b/core/device.cpp
-@@ -29,7 +29,9 @@ namespace oidn {
-
- Device::~Device()
- {
-- observer.reset();
-+ // -- GODOT start --
-+ //observer.reset();
-+ // -- GODOT end --
- }
-
- void Device::setError(Device* device, Error code, const std::string& message)
-@@ -141,6 +143,9 @@ namespace oidn {
- if (isCommitted())
- throw Exception(Error::InvalidOperation, "device can be committed only once");
-
-+ // -- GODOT start --
-+ #if 0
-+ // -- GODOT end --
- // Get the optimal thread affinities
- if (setAffinity)
- {
-@@ -157,7 +162,10 @@ namespace oidn {
- // Automatically set the thread affinities
- if (affinity)
- observer = std::make_shared<PinningObserver>(affinity, *arena);
--
-+ // -- GODOT start --
-+ #endif
-+ numThreads = 1;
-+ // -- GODOT end --
- dirty = false;
-
- if (isVerbose())
-@@ -191,9 +199,17 @@ namespace oidn {
-
- Ref<Filter> filter;
-
-+// -- GODOT start --
-+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-+#if 0
-+// -- GODOT end --
- if (type == "RT")
- filter = makeRef<RTFilter>(Ref<Device>(this));
-- else if (type == "RTLightmap")
-+// -- GODOT start --
-+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
-+#endif
-+ if (type == "RTLightmap")
-+// -- GODOT end --
- filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
- else
- throw Exception(Error::InvalidArgument, "unknown filter type");
-@@ -210,11 +226,12 @@ namespace oidn {
- std::cout << " Build : " << getBuildName() << std::endl;
- std::cout << " Platform: " << getPlatformName() << std::endl;
-
-- std::cout << " Tasking :";
-- std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
-- std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
-- std::cout << std::endl;
--
-+// -- GODOT start --
-+// std::cout << " Tasking :";
-+// std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
-+// std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
-+// std::cout << std::endl;
-+// -- GODOT end --
- std::cout << std::endl;
- }
-
-diff --git a/core/device.h b/core/device.h
-index c2df714..d9cfd85 100644
---- a/core/device.h
-+++ b/core/device.h
-@@ -41,10 +41,12 @@ namespace oidn {
- ErrorFunction errorFunc = nullptr;
- void* errorUserPtr = nullptr;
-
-- // Tasking
-- std::shared_ptr<tbb::task_arena> arena;
-- std::shared_ptr<PinningObserver> observer;
-- std::shared_ptr<ThreadAffinity> affinity;
-+// -- GODOT start --
-+// // Tasking
-+// std::shared_ptr<tbb::task_arena> arena;
-+// std::shared_ptr<PinningObserver> observer;
-+// std::shared_ptr<ThreadAffinity> affinity;
-+// -- GODOT end --
-
- // Parameters
- int numThreads = 0; // autodetect by default
-@@ -66,17 +68,19 @@ namespace oidn {
-
- void commit();
-
-- template<typename F>
-- void executeTask(F& f)
-- {
-- arena->execute(f);
-- }
-+// -- GODOT start --
-+// template<typename F>
-+// void executeTask(F& f)
-+// {
-+// arena->execute(f);
-+// }
-
-- template<typename F>
-- void executeTask(const F& f)
-- {
-- arena->execute(f);
-- }
-+// template<typename F>
-+// void executeTask(const F& f)
-+// {
-+// arena->execute(f);
-+// }
-+// -- GODOT end --
-
- Ref<Buffer> newBuffer(size_t byteSize);
- Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
-@@ -86,7 +90,10 @@ namespace oidn {
- __forceinline std::mutex& getMutex() { return mutex; }
-
- private:
-- bool isCommitted() const { return bool(arena); }
-+// -- GODOT start --
-+ //bool isCommitted() const { return bool(arena); }
-+ bool isCommitted() const { return false; }
-+// -- GODOT end --
- void checkCommitted();
-
- void print();
-diff --git a/core/network.cpp b/core/network.cpp
-index 8c2de09..ed8328c 100644
---- a/core/network.cpp
-+++ b/core/network.cpp
-@@ -17,6 +17,9 @@
- #include "upsample.h"
- #include "weights_reorder.h"
- #include "network.h"
-+// -- GODOT start --
-+#include <cstring>
-+// -- GODOT end --
-
- namespace oidn {
-
-diff --git a/core/transfer_function.cpp b/core/transfer_function.cpp
-index 601f814..ce5deca 100644
---- a/core/transfer_function.cpp
-+++ b/core/transfer_function.cpp
-@@ -38,16 +38,24 @@ namespace oidn {
- // Compute the average log luminance of the downsampled image
- using Sum = std::pair<float, int>;
-
-- Sum sum =
-- tbb::parallel_reduce(
-- tbb::blocked_range2d<int>(0, HK, 0, WK),
-- Sum(0.f, 0),
-- [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
-+ // -- GODOT start --
-+ // Sum sum =
-+ // tbb::parallel_reduce(
-+ // tbb::blocked_range2d<int>(0, HK, 0, WK),
-+ // Sum(0.f, 0),
-+ // [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
-+ // {
-+ // // Iterate over blocks
-+ // for (int i = r.rows().begin(); i != r.rows().end(); ++i)
-+ // {
-+ // for (int j = r.cols().begin(); j != r.cols().end(); ++j)
-+ // {
-+
-+ Sum sum = Sum(0.0f, 0);
-+
-+ for (int i = 0; i != HK; ++i)
- {
-- // Iterate over blocks
-- for (int i = r.rows().begin(); i != r.rows().end(); ++i)
-- {
-- for (int j = r.cols().begin(); j != r.cols().end(); ++j)
-+ for (int j = 0; j != WK; ++j)
- {
- // Compute the average luminance in the current block
- const int beginH = int(ptrdiff_t(i) * H / HK);
-@@ -82,11 +90,12 @@ namespace oidn {
- }
- }
-
-- return sum;
-- },
-- [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
-- tbb::static_partitioner()
-- );
-+ // return sum;
-+ // },
-+ // [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
-+ // tbb::static_partitioner()
-+ // );
-+ // -- GODOT end --
-
- return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
- }
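For context on what the patch above trims down: once committed, the one remaining "RTLightmap" filter is driven through the stock OIDN C API. A sketch assuming float3 HDR buffers (inPtr, outPtr, width and height are placeholder names):

    OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
    oidnCommitDevice(device);

    OIDNFilter filter = oidnNewFilter(device, "RTLightmap");   // the only filter type Godot keeps
    oidnSetSharedFilterImage(filter, "color",  inPtr,  OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
    oidnSetSharedFilterImage(filter, "output", outPtr, OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
    oidnCommitFilter(filter);
    oidnExecuteFilter(filter);

    const char* msg;
    if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE)
        printf("OIDN error: %s\n", msg);

    oidnReleaseFilter(filter);
    oidnReleaseDevice(device);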
diff --git a/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch b/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch
deleted file mode 100644
index 50d94ebffa..0000000000
--- a/thirdparty/oidn/patches/mkl-dnn-fix-vs2017-build.patch
+++ /dev/null
@@ -1,45 +0,0 @@
-Rediffed by @akien-mga to match oidn 1.1.0 source.
-
-From 1e42e6db81e1a5270ecc0191c5385ce7e7d978e9 Mon Sep 17 00:00:00 2001
-From: Jeremy Wong <jmw@netvigator.com>
-Date: Wed, 11 Sep 2019 04:46:53 +0800
-Subject: [PATCH] src: initialize members in some structures to prevent compile
- errors with VS2017
-
-addresses "error C3615: constexpr function '...' cannot result in a constant expression" with VS2017
----
- src/cpu/rnn/rnn_reorders.hpp | 2 +-
- src/cpu/simple_concat.hpp | 6 +++---
- src/cpu/simple_sum.hpp | 2 +-
- 3 files changed, 5 insertions(+), 5 deletions(-)
-
-diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
-index 597c63e3f8..ae1551390a 100644
---- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
-+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
-@@ -131,7 +131,7 @@ struct rnn_weights_reorder_t : public cpu_primitive_t {
- return status::success;
- }
-
-- format_tag_t itag_;
-+ format_tag_t itag_ = mkldnn_format_tag_undef;
-
- private:
- void init_scratchpad() {
-diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
-index 5177275452..057cc3c4c7 100644
---- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
-+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
-@@ -96,9 +96,9 @@ struct simple_concat_t: public cpu_primitive_t {
- return status::success;
- }
-
-- int perm_[MKLDNN_MAX_NDIMS];
-- int iperm_[MKLDNN_MAX_NDIMS];
-- dims_t blocks_;
-+ int perm_[MKLDNN_MAX_NDIMS] {};
-+ int iperm_[MKLDNN_MAX_NDIMS] {};
-+ dims_t blocks_ {};
-
- dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
- const int ndims = data_d.ndims();
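The fix in the patch above is plain default member initialization, so the affected pd_t constructors no longer trip VS2017's constexpr evaluation (error C3615). The same idiom in isolation (an illustrative struct, not mkl-dnn code):

    struct pd_t {
        int    perm_[8] {};    // brace-init zeroes the whole array
        int    iperm_[8] {};
        double blocks_[8] {};  // stands in for dims_t
        int    itag_ = 0;      // scalar member with an explicit default
    };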
diff --git a/thirdparty/oidn/weights/LICENSE.txt b/thirdparty/oidn/weights/LICENSE.txt
deleted file mode 100644
index d645695673..0000000000
--- a/thirdparty/oidn/weights/LICENSE.txt
+++ /dev/null
@@ -1,202 +0,0 @@
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/thirdparty/oidn/weights/rtlightmap_hdr.tza b/thirdparty/oidn/weights/rtlightmap_hdr.tza
deleted file mode 100644
index 12459a33bc..0000000000
--- a/thirdparty/oidn/weights/rtlightmap_hdr.tza
+++ /dev/null
Binary files differ