summaryrefslogtreecommitdiffstats
path: root/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp265
1 files changed, 265 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
new file mode 100644
index 0000000000..d79b1a034b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
@@ -0,0 +1,265 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "simple_q10n.hpp"
+
+#include "ref_batch_normalization.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+void ref_batch_normalization_fwd_t<data_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ /* fast return */
+ if (this->pd()->has_zero_dim_memory()) return;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT);
+
+ auto mean = pd()->stats_is_src()
+ ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN))
+ : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN);
+ auto variance = pd()->stats_is_src()
+ ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE))
+ : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE);
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_md());
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ dim_t H = 1, W = 1, D = 1;
+ const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
+ if (has_spatial) {
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
+ }
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();;
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+ const bool calculate_stats = !pd()->stats_is_src();
+
+ const bool with_relu = pd()->with_relu_post_op();
+ auto maybe_post_op = [&](float res) {
+ return (with_relu && res < 0.0f) ? 0.0f : res;
+ };
+ const bool is_3d = data_d.ndims() == 5;
+
+ auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
+ dim_t d, dim_t h, dim_t w) {
+ if (has_spatial) {
+ if (is_3d)
+ return data_d.off(n, c, d, h, w);
+ else
+ return data_d.off(n, c, h, w);
+ } else
+ return data_d.off(n, c);
+ };
+
+ parallel_nd(C, [&](dim_t c) {
+ float v_mean = calculate_stats ? 0 : mean[c];
+ float v_variance = calculate_stats ? 0 : variance[c];
+
+ if (calculate_stats) {
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w)
+ v_mean += src[data_offset(data_d, n, c, d, h, w)];
+ v_mean /= W*N*H*D;
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean;
+ v_variance += m*m;
+ }
+ v_variance /= W*H*N*D;
+ }
+
+ float sqrt_variance = sqrtf(v_variance + eps);
+ float sm = (use_scaleshift
+ ? scaleshift[scaleshift_d.off(0, c)]
+ : 1.0f) / sqrt_variance;
+ float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ auto d_off = data_offset(data_d,n,c,d,h,w);
+ float bn_res = sm * ((float)src[d_off] - v_mean) + sv;
+ if (fuse_bn_relu) {
+ if (bn_res <= 0) {
+ bn_res = 0;
+ if (is_training)
+ ws[d_off] = 0;
+ } else {
+ if (is_training)
+ ws[d_off] = 1;
+ }
+ }
+ if (data_type == data_type::s8) {
+ dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
+ } else {
+ dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res));
+ }
+ }
+
+ if (calculate_stats) {
+ if (save_stats) {
+ mean[c] = v_mean;
+ variance[c] = v_variance;
+ }
+ }
+ });
+}
+
+template struct ref_batch_normalization_fwd_t<data_type::f32>;
+template struct ref_batch_normalization_fwd_t<data_type::s8>;
+
+template <impl::data_type_t data_type>
+void ref_batch_normalization_bwd_t<data_type>::execute_backward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_md());
+ const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md());
+
+ const dim_t C = pd()->C();
+
+ /* fast return */
+ if (this->pd()->has_zero_dim_memory()) {
+ if (diff_scaleshift) {
+ for (dim_t c = 0; c < C; ++c) {
+ diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
+ diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
+ }
+ }
+ return;
+ }
+
+ const dim_t N = pd()->MB();
+ dim_t H = 1, W = 1, D = 1;
+ const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
+ if (has_spatial) {
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
+ }
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ const bool is_3d = data_d.ndims() == 5;
+
+ auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
+ dim_t d, dim_t h, dim_t w) {
+ if (has_spatial) {
+ if (is_3d)
+ return data_d.off(n, c, d, h, w);
+ else
+ return data_d.off(n, c, h, w);
+ } else
+ return data_d.off(n, c);
+ };
+
+ parallel_nd(C, [&](dim_t c) {
+ data_t v_mean = mean[c];
+ data_t v_variance = variance[c];
+ data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
+ data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
+ data_t diff_gamma = data_t(0);
+ data_t diff_beta = data_t(0);
+ diff_gamma = 0.0;
+ diff_beta = 0.0;
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ const size_t s_off = data_offset(data_d, n, c, d, h, w);
+ data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)];
+ if (fuse_bn_relu && !ws[s_off])
+ dd = 0;
+
+ diff_gamma += (src[s_off] - v_mean) * dd;
+ diff_beta += dd;
+ }
+ diff_gamma *= sqrt_variance;
+
+ if (diff_scaleshift) {
+ diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
+ diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
+ }
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ const size_t s_off = data_offset(data_d, n, c, d, h, w);
+ const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
+ data_t dd = diff_dst[dd_off];
+ if (fuse_bn_relu && !ws[s_off])
+ dd = 0;
+
+ data_t v_diff_src = dd;
+ if (calculate_diff_stats) {
+ v_diff_src -= diff_beta/(D*W*H*N) +
+ (src[s_off] - v_mean) *
+ diff_gamma*sqrt_variance/(D*W*H*N);
+ }
+ v_diff_src *= gamma*sqrt_variance;
+ diff_src[dd_off] = v_diff_src;
+ }
+ });
+}
+
+template struct ref_batch_normalization_bwd_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s