path: root/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
author    Dario <dariosamo@gmail.com>  2023-09-18 10:05:20 -0300
committer Dario <dariosamo@gmail.com>  2023-09-25 14:53:45 -0300
commit    ab65effed015df76b0858df27127f62b3aa94e0e (patch)
tree      cab7bbbdd2b63235b809560e47c3ac3784fa892b /thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
parent    1b2b726502eabaae4a15d544d92735cc2efe35b5 (diff)
download  redot-engine-ab65effed015df76b0858df27127f62b3aa94e0e.tar.gz
Remove denoise module and thirdparty OIDN.
This is replaced by a much lighter-weight and faster JNLM denoiser. OIDN is still much more accurate and may be offered as an optional backend in the future, but the JNLM denoiser seems good enough for most use cases, and removing OIDN significantly reduces build system complexity, binary size, and build times.
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788
1 file changed, 0 insertions(+), 788 deletions(-)
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
deleted file mode 100644
index ead536816c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
+++ /dev/null
@@ -1,788 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
- General architecture
-
- for diff states, we have n_states + 1 slots: n_states diffs
- to propagate to the previous iteration and 1 state to propagate
- to the previous layer
- index 0 is dh for cell(t-1, l) to consume
- index 1 is dc for cell(t-1, l) to consume
- index 2 is dh for cell(t, l-1) to consume
- this indexing lets the elementwise functions use the same indexing
- for states; only the cell execution function should be impacted
-
- */
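(As a reading aid, not part of the deleted file: a minimal sketch of the slot
layout the comment above describes, using hypothetical names.)

// Hypothetical names for the three diff-state slots described above.
// Slots [0, n_states) flow to the previous iteration; slot n_states
// flows to the previous layer (n_states == 2 here, as for LSTM).
enum diff_state_slot {
    DIFF_H_PREV_ITER  = 0, // dh consumed by cell(t-1, l)
    DIFF_C_PREV_ITER  = 1, // dc consumed by cell(t-1, l)
    DIFF_H_PREV_LAYER = 2  // dh consumed by cell(t, l-1)
};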
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-#include "../gemm/gemm.hpp"
-#include "../simple_q10n.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace rnn_utils;
-#define AOC array_offset_calculator
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
- const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
- float *diff_bias_) const {
- auto body = [&](int i, int k) {
- for (int j = 0; j < rnn.mb; j++)
- diff_bias_[i * rnn.dic + k]
- += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
- };
-
- // @todo block k on simd-width
-#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
- /* icc 17.0 has a problem with simd collapse */ \
- && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
-#pragma omp parallel for simd collapse(2)
- for (int i = 0; i < rnn.n_gates; i++)
- for (int k = 0; k < rnn.dic; k++)
- body(i, k);
-#else
- parallel_nd(rnn.n_gates, rnn.dic, body);
-#endif
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
- assert(ldA * ldB * ldC != 0);
- extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
- &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
- assert(!"non packed gemm is disabled for int8");
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
- assert(transA == 'N');
- cblas_sgemm_compute(CblasColMajor, CblasPacked,
- (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
- ldB, beta, c_, ldC);
-#else
- UNUSED(transA);
- UNUSED(transB);
- UNUSED(m);
- UNUSED(n);
- UNUSED(k);
- UNUSED(alpha);
- UNUSED(ldA);
- UNUSED(b_);
- UNUSED(ldB);
- UNUSED(beta);
- UNUSED(c_);
- UNUSED(ldC);
- assert(!"packed gemm is disabled");
-#endif
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
- int8_t offseta = 0, offsetb = 0;
- int32_t offsetc = 0;
- cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
- CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
- ldB, offsetb, beta, c_, ldC, &offsetc);
-#else
- UNUSED(transA);
- UNUSED(transB);
- UNUSED(m);
- UNUSED(n);
- UNUSED(k);
- UNUSED(alpha);
- UNUSED(ldA);
- UNUSED(b_);
- UNUSED(ldB);
- UNUSED(beta);
- UNUSED(c_);
- UNUSED(ldC);
- assert(!"packed gemm is disabled");
-#endif
-}
-
-//*************** Grid computations strategy: linear ***************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_grid_execution_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
- AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
- AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
- AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- (rnn.n_states + 1), rnn.n_iter + 1,
- rnn.states_nld * rnn.states_ws_ld);
- AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
- rnn.gates_nld * rnn.gates_ws_ld);
- AOC<weights_data_t *, 3> weights_input(
- weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
- AOC<weights_data_t *, 3> weights_states(
- weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
- AOC<float*, 3> bias(
- bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
- AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
- rnn.n_dir,
- rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
- AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
- rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
- AOC<float, 3> diff_bias(
- diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
- AOC<float, 4> ws_grid(
- ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);
-
- // We run the grid of computation
- for (int dir = 0; dir < rnn.n_dir; dir++) {
- for (int j = 0; j < rnn.n_layer; j++) {
- int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;
-
- if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
- rnn.mb * rnn.n_iter, rnn.slc, 1.0,
- weights_input(lay, dir, 0), rnn.weights_iter_ld,
- &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
- &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
- }
-
- for (int i = 0; i < rnn.n_iter; i++) {
- int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;
- (this->*cell_func)(rnn,
- &(ws_states(lay + 1, dir, iter + 1, 0)),
- &(ws_c_states(lay + 1, dir, iter + 1, 0)),
- &(ws_diff_states(lay, dir, 0, iter, 0)),
- &(weights_input(lay, dir, 0)),
- &(weights_states(lay, dir, 0)),
- &(bias(lay, dir, 0)),
- &(ws_states(lay, dir, iter + 1, 0)),
- &(ws_states(lay + 1, dir, iter, 0)),
- &(ws_c_states(lay + 1, dir, iter, 0)),
- &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
- &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
- &(diff_weights_layer(lay, dir, 0)),
- &(diff_weights_iter(lay, dir, 0)),
- &(diff_bias(lay, dir, 0)),
- &(ws_gates(lay, dir, iter, 0)),
- &(ws_grid(lay, dir, iter, 0)),
- ws_cell_);
- }
-
- if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
- (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
- rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
- rnn.weights_layer_ld,
- (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld, 0.0,
- (acc_data_t *)(&(ws_diff_states(
- lay, dir, rnn.n_states, 0, 0))),
- rnn.states_ws_ld);
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
- rnn.mb * rnn.n_iter, 1.0,
- (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld,
- (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
- rnn.states_ws_ld, 1.0,
- (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
- rnn.diff_weights_layer_ld);
- }
- if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
- gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
- rnn.mb * rnn.n_iter, 1.0,
- (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
- rnn.gates_ws_ld,
- (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
- rnn.states_ws_ld, 1.0,
- (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
- rnn.diff_weights_iter_ld);
- }
- }
- }
-}
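(For orientation: an illustrative skeleton of the grid walk above, not the
deleted code itself; the real implementation dispatches through member
function pointers and the merge_gemm_* flags.)

// Hypothetical free function showing only the loop nest: layers and
// iterations run forward for fwd prop and reversed for backprop.
void grid_walk_skeleton(int n_dir, int n_layer, int n_iter, bool fwd) {
    for (int dir = 0; dir < n_dir; ++dir)
        for (int j = 0; j < n_layer; ++j) {
            int lay = fwd ? j : n_layer - j - 1;
            // fwd + merge_gemm_layer: one merged input GEMM for all iterations
            for (int i = 0; i < n_iter; ++i) {
                int iter = fwd ? i : n_iter - i - 1;
                // cell_func(lay, dir, iter): gate GEMMs + elementwise
                (void)lay; (void)iter;
            }
            // bwd + merge_gemm_layer / merge_gemm_iter: merged diff GEMMs
        }
}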
-
-//********* GRID computations strategy: utility functions **********//
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
- const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
- float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
- const float *__restrict diff_dst_layer_) const {
-
- AOC<src_data_t, 4> ws_states(
- ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto xt_d = memory_desc_wrapper(pd()->src_md(0));
-
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto xxt = xt_ + xt_d.blk_off(it, b);
- src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
- src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
- if (rnn.exec_dir != r2l)
- for (int c = 0; c < rnn.slc; c++)
- ws_l2r_ptr[c] = xxt[c];
- if (rnn.exec_dir != l2r)
- for (int c = 0; c < rnn.slc; c++)
- ws_r2l_ptr[c] = xxt[c];
- });
-}
-
-template <>
-void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
- const float *diff_dst_layer_) const {
- AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));
-
- switch (rnn.exec_dir) {
- case bi_concat:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(
- rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
- = diff_dst_layer_x[rnn.dic + s];
- }
- });
- break;
- case bi_sum:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- ws_diff_states(
- rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- case l2r:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x
- = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- case r2l:
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- auto diff_dst_layer_x = diff_dst_layer_
- + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
- for (int s = 0; s < rnn.dic; s++) {
- ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
- = diff_dst_layer_x[s];
- }
- });
- break;
- default: assert(!"Unsupported direction"); break;
- }
-}
-
-/* For the int8 configuration, input iteration states may be of type f32 or u8.
- * Internally, h_state is always stored as u8 and c_state is always stored as f32.
- * If the input states are u8, the h state is copied and the c state is dequantized.
- * If the input states are f32, the h state is quantized and the c state is copied.
- */
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename input_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
- const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
- float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
- const input_data_t *__restrict firstit_states_,
- const float *__restrict diff_dst_iter_) const {
- AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
- const bool quantize = pd()->with_src_iter()
- && pd()->src_md(1)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_q = [&](input_data_t f) {
- if (quantize) {
- float qf = f * data_scale + data_shift;
- return qz_a1b0<float, src_data_t>()(qf);
- } else
- return (src_data_t)f;
- };
-
- const bool dequantize = pd()->with_src_iter()
- && pd()->src_md(1)->data_type == data_type::u8;
- auto maybe_deq = [&](input_data_t s) {
- if (dequantize)
- return (((float)s - data_shift) / data_scale);
- else
- return (float)s;
- };
- auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
- if (firstit_states_) {
- parallel_nd(
- rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
- for (int s = 0; s < rnn.sic; s++)
- ws_states(lay + 1, dir, 0, b, s) = maybe_q(
- firstit_states_[firstit_states_d.blk_off(
- lay, dir, 0, b, s)]);
- if (pd()->cell_kind() == alg_kind::vanilla_lstm)
- for (int s = 0; s < rnn.sic; s++)
- ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
- firstit_states_[firstit_states_d.blk_off(
- lay, dir, 1, b, s)]);
- });
- } else {
- parallel_nd(
- rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
- for (int j = 0; j < rnn.sic; j++) {
- ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
- ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
- }
- });
- }
-}
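(A self-contained sketch of the quantization round-trip described in the
comment above copy_init_iter; the helper names are hypothetical, and the
rounding/saturation of qz_a1b0 is approximated here with clamping.)

#include <algorithm>
#include <cmath>
#include <cstdint>

// Quantize an f32 h state: u8 = saturate(round(f * scale + shift)).
static inline uint8_t quantize_h(float f, float scale, float shift) {
    float qf = std::nearbyint(f * scale + shift);
    return (uint8_t)std::min(255.0f, std::max(0.0f, qf));
}

// Dequantize a u8 c state: f32 = (q - shift) / scale.
static inline float dequantize_c(uint8_t q, float scale, float shift) {
    return ((float)q - shift) / scale;
}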
-
-template <>
-template <typename input_data_t>
-void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
- src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
- const input_data_t *firstit_states_,
- const float *diff_dst_iter_) const {
- AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
- if (diff_dst_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int b) {
- array_copy(&(ws_diff_states(
- lay, dir, state, rnn.n_iter, b, 0)),
- diff_dst_iter_
- + diff_dst_iter_d.blk_off(
- lay, dir, state, b),
- rnn.dic);
- });
- } else {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int i) {
- for (int j = 0; j < rnn.dic; j++)
- ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
- = 0.0f;
- });
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename dst_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
- const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
- const src_data_t *ws_states_, const float *ws_diff_states_) const {
-
- auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
- AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float shift = (pd()->attr()->rnn_data_qparams_.shift_);
- float scale = (pd()->attr()->rnn_data_qparams_.scale_);
-
- const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_deq = [&](src_data_t s) {
- if (dequantize)
- return (dst_data_t)(((float)s - shift) / scale);
- else
- return (dst_data_t)s;
- };
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- int dir = 0;
- if (rnn.exec_dir != r2l) {
- for (int s = 0; s < rnn.dic; s++) {
- dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
- = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
- }
- dir = 1;
- }
- if (rnn.exec_dir != l2r) {
- for (int s = 0; s < rnn.dic; s++)
- switch (rnn.exec_dir) {
- case bi_sum:
- dst_layer_[dst_layer_d.blk_off(it, b, s)]
- += maybe_deq(ws_states(
- rnn.n_layer, dir, rnn.n_iter - it, b, s));
- break;
- default:
- dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
- = maybe_deq(ws_states(
- rnn.n_layer, dir, rnn.n_iter - it, b, s));
- }
- }
- });
-}
-
-template <>
-template <typename dst_data_t>
-void ref_rnn_bwd_f32_t::copy_res_layer(
- const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
- const src_data_t *ws_states_, const float *ws_diff_states_) const {
- auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
- AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
- rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
- rnn.states_ws_ld);
-
- parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
- int dir = 0;
- for (int s = 0; s < rnn.slc; s++) {
- float *dst_addr = diff_src_layer_
- + diff_src_layer_d.blk_off(
- (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
- b, dir * rnn.slc + s);
- float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
- if (rnn.n_dir - 1)
- res += ws_diff_states(
- 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
- dst_addr[0] = res;
- }
- });
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename output_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
- const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
- const src_data_t *ws_states_, float *ws_c_states_,
- const float *ws_diff_states_) const {
- auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
- AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
- rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
- const bool quantize = pd()->with_dst_iter()
- && pd()->dst_md(1)->data_type == data_type::u8
- && rnn.dt_conf != all_f32;
- auto maybe_q = [&](float f) {
- if (quantize) {
- float qf = f * data_scale + data_shift;
- return qz_a1b0<float, output_data_t>()(qf);
- } else
- return (output_data_t)f;
- };
-
- const bool dequantize = pd()->with_dst_iter()
- && pd()->dst_md(1)->data_type == data_type::f32
- && rnn.dt_conf != all_f32;
- auto maybe_deq = [&](src_data_t s) {
- if (dequantize)
- return (output_data_t)(((float)s - data_shift) / data_scale);
- else
- return (output_data_t)s;
- };
- if (dst_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
- [&](int lay, int dir, int b) {
- for (int s = 0; s < rnn.dic; s++) {
- dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
- = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s));
- }
- if (pd()->cell_kind() == alg_kind::vanilla_lstm)
- for (int s = 0; s < rnn.dic; s++) {
- dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
- = maybe_q(ws_c_states(
- lay + 1, dir, rnn.n_iter, b, s));
- }
- });
- }
-}
-
-template <>
-template <typename output_data_t>
-void ref_rnn_bwd_f32_t::copy_res_iter(
- const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
- const src_data_t *ws_states_, float *ws_c_states_,
- const float *ws_diff_states_) const {
- auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
- AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
- rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
- rnn.states_ws_ld);
- if (diff_src_iter_) {
- parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
- [&](int lay, int dir, int state, int b) {
- for (int s = 0; s < rnn.sic; s++) {
- diff_src_iter_[diff_src_iter_d.blk_off(
- lay, dir, state, b, s)]
- = ws_diff_states(lay, dir, state, 0, b, s);
- }
- });
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
- /* Original set of bias provided by the user */
- AOC<const float, 5> b(
- b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
- /* Array of pointers initialized in packing */
- AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
- AOC<float, 3> scratch_bias(
- scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
-
- if (rnn.copy_bias) {
- parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
- [&](size_t i) { scratch_bias_[i] = b_[i]; });
- }
-
- for (int i = 0; i < rnn.n_layer; i++) {
- for (int d = 0; d < rnn.n_dir; d++) {
- int offset_bias = 0;
- for (int p = 0; p < rnn.n_parts_bias; p++) {
- bias(i, d, p) = rnn.copy_bias
- ? (float *) &scratch_bias(i, d, offset_bias)
- : (float *) &b(i, d, offset_bias);
- offset_bias += rnn.parts_bias[p] * rnn.dic;
- }
- }
- }
-
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_finalize_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
- if (rnn.dt_conf != all_f32) {
- float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
- float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
- float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
- bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
- for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
- for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
- size_t off = i * rnn.n_bias * rnn.dic + j;
- float weights_scale
- = scale_per_oc ? weights_scales[j] : weights_scales[0];
- scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
- * data_shift / (weights_scale * data_scale);
- }
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
- weights_type>::assign_packed_weights)) {
- assert(md->format_kind == format_kind::rnn_packed);
- const auto packed_desc = md->format_desc.rnn_packed_desc;
- AOC<weights_data_t *, 3> weights(weights_,
- rnn.n_layer, rnn.n_dir, packed_desc.n_parts);
-
- size_t offset_packed = 0;
- for (int l = 0; l < rnn.n_layer; l++)
- for (int d = 0; d < rnn.n_dir; d++) {
- for (int p = 0; p < packed_desc.n_parts; p++) {
- weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
- offset_packed
- += packed_desc.part_pack_size[p] / sizeof(weights_data_t);
- }
- }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig(
- (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
- assert(md->format_kind == format_kind::blocked);
- const auto &blk = md->format_desc.blocking;
- /* Original set of weights provided by the user */
- AOC<const weights_data_t, 3> w(w_,
- rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
- /* Array of pointers for each part of weights */
- AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);
-
- for (int i = 0; i < rnn.n_layer; i++)
- for (int d = 0; d < rnn.n_dir; d++) {
- size_t offset_weights = 0;
- for (int p = 0; p < n_parts; p++) {
- weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
- offset_weights += gates_per_part[p] * blk.strides[3];
- }
- }
-}
-
-//********************* Execution function *********************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
- const exec_ctx_t &ctx) const {
- const rnn_conf_t &rnn = this->pd()->rnn_;
- auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
- auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
- auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
- auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
- auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
-
- auto dst_last_layer = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
- auto dst_last_iter = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));
-
- auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
- auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);
-
- auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
- auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
- auto w_iter_comp = reinterpret_cast<const float *>(
- iter_weights_n_comp + rnn.weights_iter_comp_offset);
- auto w_layer_comp = reinterpret_cast<const float *>(
- layer_weights_n_comp + rnn.weights_layer_comp_offset);
-
- auto scratchpad = this->scratchpad(ctx);
-
- auto ptr_wei_layer
- = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
- auto ptr_wei_iter
- = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
- auto ptr_bias =
- scratchpad.template get<float *>(key_rnn_ptrs_bia);
-
-    // fetching buffers from the workspace
-    // if no workspace was provided we use the scratchpad
- char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
- char *ws_ptr = nullptr;
- if (rnn.use_workspace)
- ws_ptr = rnn.is_fwd
- ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
- : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
-
- char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
- acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
- src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
- float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
- float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
- float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
- float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);
-
- auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
- auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);
-
- auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
- auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
- auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
-
- // Fetching extra buffers from scratchpad
- float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);
-
- // initialize diff_states to 0
- if (aprop == prop_kind::backward)
- array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));
-
-    /* Pack (if using the packed gemm API) or copy (if the input arrays have a
-     * bad leading dimension). */
- (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
-
- (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
- rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
- rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
- rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
- ptr_bias, bias, ws_bias);
- (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
- rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
- rnn.n_parts_weights_layer, rnn.parts_weights_layer,
- rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
- bias, ws_bias);
-
- (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
-
- // we first need to copy the initial states and input into ws
- copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
- if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
- (const float *)states, diff_dst_iter);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
- copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
- (const uint8_t *)states, diff_dst_iter);
- else
- assert(!"unimplemented");
-
- // run the execution on the grid
- (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
- ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
- diff_weights_layer, diff_weights_iter, diff_bias);
-
- // Finally we copy the results to the result buffers
- if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
- ws_diff_states);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
- copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
- ws_states, ws_diff_states);
- else
- assert(!"unimplemented");
-
- if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
- || rnn.dt_conf == all_f32)
- copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
- ws_c_states, ws_diff_states);
- else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
- copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
- ws_c_states, ws_diff_states);
- else
- assert(!"unimplemented");
-};
-
-/* Fix for MSVS warning C4661 */
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
-
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
-template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
-
-#undef AOC
-}
-}
-}