diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp | 288 |
1 files changed, 288 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp new file mode 100644 index 0000000000..e20333e66f --- /dev/null +++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp @@ -0,0 +1,288 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include <assert.h> +#include <math.h> + +#include "c_types_map.hpp" +#include "type_helpers.hpp" + +#include "cpu_batch_normalization_utils.hpp" +#include "jit_generator.hpp" + +#include "nspc_batch_normalization.hpp" + +// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular cases +#if (defined __clang_major__) && (__clang_major__ >= 6) +#define SAFE_TO_USE_OMP_SIMD 0 +#else +#define SAFE_TO_USE_OMP_SIMD 1 +#endif + +namespace mkldnn { +namespace impl { +namespace cpu { + +using namespace memory_tracking::names; + +void nspc_batch_normalization_fwd_t::execute_forward( + const exec_ctx_t &ctx) const { + const bool save_stats = pd()->is_training(); + const bool is_training = pd()->is_training(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + const bool calculate_stats = !pd()->stats_is_src(); + const bool with_relu = pd()->with_relu_post_op(); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean); + auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var); + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + + data_t *mean, *variance; + if (!calculate_stats) { + mean = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN)); + variance = const_cast<data_t *>( + CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE)); + } else { + if (save_stats) { + mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN); + variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE); + } else { + mean = tmp_mean; + variance = tmp_var; + } + } + + auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); + auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE); + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->H() * pd()->W() * pd()->D(); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + auto maybe_post_op + = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; }; + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr; + data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr; + + if (calculate_stats) { + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] = 0.; + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) + ws_reduce[C * ithr + c] += src[(size_t)n * SP * C + + sp * C + c]; + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + mean[c] = 0; + for (dim_t n = 0; n < nthr; n++) + mean[c] += ws_reduce[C * n + c]; + mean[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + mean_loc[c] = mean[c]; + ws_reduce[C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) + PRAGMA_OMP_SIMD() + for (dim_t c = 0; c < C; c++) { + data_t m = src[(size_t)n * SP * C + sp * C + c] + - mean_loc[c]; + ws_reduce[C * ithr + c] += m * m; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + variance[c] = 0; + for (dim_t n = 0; n < nthr; n++) + variance[c] += ws_reduce[C * n + c]; + variance[c] /= SP * N; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) + variance_loc[c] = variance[c]; + } else { + variance_loc = variance; + mean_loc = mean; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + data_t sqrt_variance = static_cast<data_t>( + sqrtf(variance_loc[c] + eps)); + data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance; + data_t sv = use_scaleshift ? scaleshift[C + c] : 0; + size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv; + if (fuse_bn_relu) { + if (bn_res <= 0) { + bn_res = 0; + if (is_training) + ws[d_off] = 0; + } else { + if (is_training) + ws[d_off] = 1; + } + } + dst[d_off] = maybe_post_op(bn_res); + } + } + } + }); +} + +void nspc_batch_normalization_bwd_t::execute_backward( + const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC); + auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN); + auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE); + auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); + auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT); + auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE); + + auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC); + auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT); + + auto scratchpad = this->scratchpad(ctx); + auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss); + + if (diff_scaleshift == nullptr) + diff_scaleshift = tmp_diff_ss; + + const dim_t N = pd()->MB(); + const dim_t C = pd()->C(); + const dim_t SP = pd()->D() * pd()->H() * pd()->W(); + data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C; + auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction); + + const float eps = pd()->desc()->batch_norm_epsilon; + const bool use_scaleshift = pd()->use_scaleshift(); + const bool calculate_diff_stats = !pd()->use_global_stats(); + const bool fuse_bn_relu = pd()->fuse_bn_relu(); + + assert(mkldnn_thr_syncable()); + parallel(0, [&](const int ithr, const int nthr) { + dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0; + balance211(N, nthr, ithr, N_s, N_e); + balance211(C, nthr, ithr, C_s, C_e); + + data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr; + data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr); + + for (dim_t c = 0; c < C; c++) { + ws_reduce[C * ithr + c] = 0.; + ws_reduce[C * nthr + C * ithr + c] = 0.; + } + + for (dim_t n = N_s; n < N_e; n++) + for (dim_t sp = 0; sp < SP; sp++) +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t dd; + if (fuse_bn_relu) + dd = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + dd = diff_dst[d_off]; + ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd; + ws_reduce[C * nthr + C * ithr + c] += dd; + } + + mkldnn_thr_barrier(); + + for (dim_t c = C_s; c < C_e; c++) { + data_t sqrt_variance + = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps)); + diff_gamma[c] = 0; + diff_beta[c] = 0; + for (dim_t n = 0; n < nthr; n++) { + diff_gamma[c] += ws_reduce[C * n + c]; + diff_beta[c] += ws_reduce[C * nthr + C * n + c]; + } + diff_gamma[c] *= sqrt_variance; + } + + mkldnn_thr_barrier(); + + for (dim_t c = 0; c < C; c++) { + diff_gamma_loc[c] = diff_gamma[c]; + diff_beta_loc[c] = diff_beta[c]; + } + + for (dim_t n = N_s; n < N_e; n++) { + for (dim_t sp = 0; sp < SP; sp++) { +#if SAFE_TO_USE_OMP_SIMD + PRAGMA_OMP_SIMD() +#endif + for (dim_t c = 0; c < C; c++) { + const size_t d_off = (size_t)n * SP * C + sp * C + c; + data_t gamma = use_scaleshift ? scaleshift[c] : 1; + data_t sqrt_variance + = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps)); + data_t v_diff_src; + if (fuse_bn_relu) + v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off]; + else + v_diff_src = diff_dst[d_off]; + if (calculate_diff_stats) { + v_diff_src -= diff_beta_loc[c] / (SP * N) + + (src[d_off] - mean[c]) * diff_gamma_loc[c] + * sqrt_variance / (SP * N); + } + v_diff_src *= gamma * sqrt_variance; + diff_src[d_off] = v_diff_src; + } + } + } + }); +} + +} +} +} + +// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |