Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp  288
1 file changed, 288 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
new file mode 100644
index 0000000000..e20333e66f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
@@ -0,0 +1,288 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
+
+#include "nspc_batch_normalization.hpp"
+
+// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular
+// cases; conservatively disable it for clang 6 and newer.
+#if (defined __clang_major__) && (__clang_major__ >= 6)
+#define SAFE_TO_USE_OMP_SIMD 0
+#else
+#define SAFE_TO_USE_OMP_SIMD 1
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
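+// Forward pass over an NSpC (channels-last) tensor: optionally computes the
+// per-channel mean and variance, then applies
+//     dst = gamma * (src - mean) / sqrt(variance + eps) + beta
+// with an optional fused ReLU.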
+void nspc_batch_normalization_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+ const bool calculate_stats = !pd()->stats_is_src();
+ const bool with_relu = pd()->with_relu_post_op();
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
+ auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var);
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+
+ data_t *mean, *variance;
+ if (!calculate_stats) {
+ mean = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
+ variance = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
+ } else {
+ if (save_stats) {
+ mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
+ variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
+ } else {
+ mean = tmp_mean;
+ variance = tmp_var;
+ }
+ }
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ const dim_t SP = pd()->H() * pd()->W() * pd()->D();
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ auto maybe_post_op
+ = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
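+    // with_relu applies ReLU as a plain post-op (no workspace); fuse_bn_relu
+    // also records the activation mask in ws during training so the backward
+    // pass can reuse it.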
+
+ assert(mkldnn_thr_syncable());
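+    // Threads split the minibatch (N) while accumulating partial sums and
+    // split the channels (C) while reducing them; the barriers below keep
+    // those phases ordered across threads.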
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
+ balance211(N, nthr, ithr, N_s, N_e);
+ balance211(C, nthr, ithr, C_s, C_e);
+ data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
+ data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
+
+ if (calculate_stats) {
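+            // Pass 1: each thread accumulates per-channel sums of its
+            // [N_s, N_e) slice of the minibatch into its own ws_reduce row.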
+ for (dim_t c = 0; c < C; c++)
+ ws_reduce[C * ithr + c] = 0.;
+
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+ PRAGMA_OMP_SIMD()
+ for (dim_t c = 0; c < C; c++)
+ ws_reduce[C * ithr + c] += src[(size_t)n * SP * C
+ + sp * C + c];
+
+ mkldnn_thr_barrier();
+
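+            // Reduce across threads: each thread sums the per-thread partial
+            // sums and finalizes the mean for its own [C_s, C_e) channel
+            // slice.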
+ for (dim_t c = C_s; c < C_e; c++) {
+ mean[c] = 0;
+ for (dim_t n = 0; n < nthr; n++)
+ mean[c] += ws_reduce[C * n + c];
+ mean[c] /= SP * N;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++) {
+ mean_loc[c] = mean[c];
+ ws_reduce[C * ithr + c] = 0.;
+ }
+
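+            // Pass 2: accumulate per-channel sums of squared deviations from
+            // the just-computed mean, again over this thread's batch slice.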
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+ PRAGMA_OMP_SIMD()
+ for (dim_t c = 0; c < C; c++) {
+ data_t m = src[(size_t)n * SP * C + sp * C + c]
+ - mean_loc[c];
+ ws_reduce[C * ithr + c] += m * m;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = C_s; c < C_e; c++) {
+ variance[c] = 0;
+ for (dim_t n = 0; n < nthr; n++)
+ variance[c] += ws_reduce[C * n + c];
+ variance[c] /= SP * N;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++)
+ variance_loc[c] = variance[c];
+ } else {
+ variance_loc = variance;
+ mean_loc = mean;
+ }
+
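+        // Normalize: dst = sm * (src - mean) + sv, where sm folds gamma with
+        // 1 / sqrt(variance + eps). With fused ReLU, negative results are
+        // clamped to zero and the mask is recorded in ws during training.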
+ for (dim_t n = N_s; n < N_e; n++) {
+ for (dim_t sp = 0; sp < SP; sp++) {
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ data_t sqrt_variance = static_cast<data_t>(
+ sqrtf(variance_loc[c] + eps));
+ data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance;
+ data_t sv = use_scaleshift ? scaleshift[C + c] : 0;
+ size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv;
+ if (fuse_bn_relu) {
+ if (bn_res <= 0) {
+ bn_res = 0;
+ if (is_training)
+ ws[d_off] = 0;
+ } else {
+ if (is_training)
+ ws[d_off] = 1;
+ }
+ }
+ dst[d_off] = maybe_post_op(bn_res);
+ }
+ }
+ }
+ });
+}
+
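+// Backward pass: computes diff_src and, per channel, diff_gamma / diff_beta
+// (the gradients of the scale and shift parameters).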
+void nspc_batch_normalization_bwd_t::execute_backward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+
+ if (diff_scaleshift == nullptr)
+ diff_scaleshift = tmp_diff_ss;
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ const dim_t SP = pd()->D() * pd()->H() * pd()->W();
+ data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C;
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ assert(mkldnn_thr_syncable());
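+    // Same decomposition as the forward pass: accumulate over a slice of the
+    // minibatch, then reduce over a slice of the channels.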
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
+ balance211(N, nthr, ithr, N_s, N_e);
+ balance211(C, nthr, ithr, C_s, C_e);
+
+ data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
+ data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);
+
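+        // ws_reduce is split into two banks of nthr rows: the first
+        // accumulates sum((src - mean) * diff_dst) per channel (for
+        // diff_gamma), the second sum(diff_dst) (for diff_beta).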
+ for (dim_t c = 0; c < C; c++) {
+ ws_reduce[C * ithr + c] = 0.;
+ ws_reduce[C * nthr + C * ithr + c] = 0.;
+ }
+
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ const size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t dd;
+ if (fuse_bn_relu)
+ dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ dd = diff_dst[d_off];
+ ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd;
+ ws_reduce[C * nthr + C * ithr + c] += dd;
+ }
+
+ mkldnn_thr_barrier();
+
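+        // Finalize diff_gamma and diff_beta for this thread's channel slice.
+        // Note that sqrt_variance actually holds 1 / sqrt(variance + eps), so
+        // diff_gamma is stored already scaled by it.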
+ for (dim_t c = C_s; c < C_e; c++) {
+ data_t sqrt_variance
+ = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
+ diff_gamma[c] = 0;
+ diff_beta[c] = 0;
+ for (dim_t n = 0; n < nthr; n++) {
+ diff_gamma[c] += ws_reduce[C * n + c];
+ diff_beta[c] += ws_reduce[C * nthr + C * n + c];
+ }
+ diff_gamma[c] *= sqrt_variance;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++) {
+ diff_gamma_loc[c] = diff_gamma[c];
+ diff_beta_loc[c] = diff_beta[c];
+ }
+
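+        // diff_src = gamma / sqrt(variance + eps) * (diff_dst
+        //         - diff_beta / (N * SP)
+        //         - (src - mean) * diff_gamma
+        //                 / ((N * SP) * sqrt(variance + eps)));
+        // the two stats-dependent terms are dropped when global statistics
+        // are used.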
+ for (dim_t n = N_s; n < N_e; n++) {
+ for (dim_t sp = 0; sp < SP; sp++) {
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ const size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t gamma = use_scaleshift ? scaleshift[c] : 1;
+ data_t sqrt_variance
+ = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
+ data_t v_diff_src;
+ if (fuse_bn_relu)
+ v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ v_diff_src = diff_dst[d_off];
+ if (calculate_diff_stats) {
+ v_diff_src -= diff_beta_loc[c] / (SP * N)
+ + (src[d_off] - mean[c]) * diff_gamma_loc[c]
+ * sqrt_variance / (SP * N);
+ }
+ v_diff_src *= gamma * sqrt_variance;
+ diff_src[d_off] = v_diff_src;
+ }
+ }
+ }
+ });
+}
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s