Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788
1 file changed, 0 insertions, 788 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
deleted file mode 100644
index ead536816c..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
+++ /dev/null
@@ -1,788 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-/*
-  General architecture
-
-  For diff states, we have n_states + 1 slots, as there are n_states
-  diffs to propagate to the previous iteration and 1 state to propagate
-  to the previous layer:
-    index 0 is dh for cell(t-1, l) to consume
-    index 1 is dc for cell(t-1, l) to consume
-    index 2 is dh for cell(t, l-1) to consume
-  This gives the elemwise functions a uniform indexing of the states;
-  only the cell execution function needs to account for it.
- */
-
-#include "math_utils.hpp"
-#include "mkldnn_thread.hpp"
-
-#include "ref_rnn.hpp"
-#include "../gemm/gemm.hpp"
-#include "../simple_q10n.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-using namespace mkldnn::impl::utils;
-using namespace mkldnn::impl::memory_tracking::names;
-using namespace rnn_utils;
-#define AOC array_offset_calculator
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
-        const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
-        float *diff_bias_) const {
-    auto body = [&](int i, int k) {
-        for (int j = 0; j < rnn.mb; j++)
-            diff_bias_[i * rnn.dic + k]
-                    += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
-    };
-
-    // @todo block k on simd-width
-#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
-    /* icc 17.0 has a problem with simd collapse */ \
-    && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
-#pragma omp parallel for simd collapse(2)
-    for (int i = 0; i < rnn.n_gates; i++)
-        for (int k = 0; k < rnn.dic; k++)
-            body(i, k);
-#else
-    parallel_nd(rnn.n_gates, rnn.dic, body);
-#endif
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
-    assert(ldA * ldB * ldC != 0);
-    extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
-            &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
-    assert(!"non packed gemm is disabled for int8");
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
-    assert(transA == 'N');
-    cblas_sgemm_compute(CblasColMajor, CblasPacked,
-            (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
-            ldB, beta, c_, ldC);
-#else
-    UNUSED(transA);
-    UNUSED(transB);
-    UNUSED(m);
-    UNUSED(n);
-    UNUSED(k);
-    UNUSED(alpha);
-    UNUSED(ldA);
-    UNUSED(b_);
-    UNUSED(ldB);
-    UNUSED(beta);
-    UNUSED(c_);
-    UNUSED(ldC);
-    assert(!"packed gemm is disabled");
-#endif
-}
-
-template <>
-rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
-#if (USE_MKL_PACKED_GEMM)
-    int8_t offseta = 0, offsetb = 0;
-    int32_t offsetc = 0;
-    cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
-            CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
-            ldB, offsetb, beta, c_, ldC, &offsetc);
-#else
-    UNUSED(transA);
-    UNUSED(transB);
-    UNUSED(m);
-    UNUSED(n);
-    UNUSED(k);
-    UNUSED(alpha);
-    UNUSED(ldA);
-    UNUSED(b_);
-    UNUSED(ldB);
-    UNUSED(beta);
-    UNUSED(c_);
-    UNUSED(ldC);
-    assert(!"packed gemm is disabled");
-#endif
-}
-
-//*************** Grid computations strategy: linear ***************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_grid_execution_sig(
-        (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
-    AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
-    AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
-    AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
-            (rnn.n_states + 1), rnn.n_iter + 1,
-            rnn.states_nld * rnn.states_ws_ld);
-    AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
-            rnn.gates_nld * rnn.gates_ws_ld);
-    AOC<weights_data_t *, 3> weights_input(
-            weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
-    AOC<weights_data_t *, 3> weights_states(
-            weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
-    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
-    AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
-            rnn.n_dir,
-            rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
-    AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
-            rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
-    AOC<float, 3> diff_bias(
-            diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
-    AOC<float, 4> ws_grid(
-            ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);
-
-    // We run the grid of computation
-    for (int dir = 0; dir < rnn.n_dir; dir++) {
-        for (int j = 0; j < rnn.n_layer; j++) {
-            int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;
-
-            if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
-                (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
-                        rnn.mb * rnn.n_iter, rnn.slc, 1.0,
-                        weights_input(lay, dir, 0), rnn.weights_iter_ld,
-                        &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
-                        &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
-            }
-
-            for (int i = 0; i < rnn.n_iter; i++) {
-                int iter = (aprop == prop_kind::forward) ? i
-                                                         : rnn.n_iter - i - 1;
-                (this->*cell_func)(rnn,
-                        &(ws_states(lay + 1, dir, iter + 1, 0)),
-                        &(ws_c_states(lay + 1, dir, iter + 1, 0)),
-                        &(ws_diff_states(lay, dir, 0, iter, 0)),
-                        &(weights_input(lay, dir, 0)),
-                        &(weights_states(lay, dir, 0)),
-                        &(bias(lay, dir, 0)),
-                        &(ws_states(lay, dir, iter + 1, 0)),
-                        &(ws_states(lay + 1, dir, iter, 0)),
-                        &(ws_c_states(lay + 1, dir, iter, 0)),
-                        &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
-                        &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
-                        &(diff_weights_layer(lay, dir, 0)),
-                        &(diff_weights_iter(lay, dir, 0)),
-                        &(diff_bias(lay, dir, 0)),
-                        &(ws_gates(lay, dir, iter, 0)),
-                        &(ws_grid(lay, dir, iter, 0)),
-                        ws_cell_);
-            }
-
-            if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
-                (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
-                        rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
-                        rnn.weights_layer_ld,
-                        (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
-                        rnn.gates_ws_ld, 0.0,
-                        (acc_data_t *)(&(ws_diff_states(
-                                lay, dir, rnn.n_states, 0, 0))),
-                        rnn.states_ws_ld);
-                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
-                        rnn.mb * rnn.n_iter, 1.0,
-                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
-                        rnn.gates_ws_ld,
-                        (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
-                        rnn.states_ws_ld, 1.0,
-                        (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
-                        rnn.diff_weights_layer_ld);
-            }
-            if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
-                gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
-                        rnn.mb * rnn.n_iter, 1.0,
-                        (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
-                        rnn.gates_ws_ld,
-                        (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
-                        rnn.states_ws_ld, 1.0,
-                        (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
-                        rnn.diff_weights_iter_ld);
-            }
-        }
-    }
-}
-
-//********* GRID computations strategy: utility functions **********//
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
-        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
-        float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
-        const float *__restrict diff_dst_layer_) const {
-
-    AOC<src_data_t, 4> ws_states(
-            ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    auto xt_d = memory_desc_wrapper(pd()->src_md(0));
-
-    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-        auto xxt = xt_ + xt_d.blk_off(it, b);
-        src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
-        src_data_t *ws_r2l_ptr
-                = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
-        if (rnn.exec_dir != r2l)
-            for (int c = 0; c < rnn.slc; c++)
-                ws_l2r_ptr[c] = xxt[c];
-        if (rnn.exec_dir != l2r)
-            for (int c = 0; c < rnn.slc; c++)
-                ws_r2l_ptr[c] = xxt[c];
-    });
-}
-
-template <>
-void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
-        src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
-        const float *diff_dst_layer_) const {
-    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
-            (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));
-
-    switch (rnn.exec_dir) {
-    case bi_concat:
-        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-            auto diff_dst_layer_x
-                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
-            for (int s = 0; s < rnn.dic; s++) {
-                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
-                        = diff_dst_layer_x[s];
-                ws_diff_states(
-                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
-                        = diff_dst_layer_x[rnn.dic + s];
-            }
-        });
-        break;
-    case bi_sum:
-        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-            auto diff_dst_layer_x
-                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
-            for (int s = 0; s < rnn.dic; s++) {
-                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
-                        = diff_dst_layer_x[s];
-                ws_diff_states(
-                        rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
-                        = diff_dst_layer_x[s];
-            }
-        });
-        break;
-    case l2r:
-        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-            auto diff_dst_layer_x
-                    = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
-            for (int s = 0; s < rnn.dic; s++) {
-                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
-                        = diff_dst_layer_x[s];
-            }
-        });
-        break;
-    case r2l:
-        parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-            auto diff_dst_layer_x = diff_dst_layer_
-                    + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
-            for (int s = 0; s < rnn.dic; s++) {
-                ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
-                        = diff_dst_layer_x[s];
-            }
-        });
-        break;
-    default: assert(!"Unsupported direction"); break;
-    }
-}
-
-/* For int8 configurations, the input iteration states may be f32 or u8.
- * Internally, h_state is always stored as u8 and c_state as f32.
- * If the input states are u8, the h state is copied and the c state is
- * dequantized; if they are f32, the h state is quantized and the c state
- * is copied. */
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename input_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
-        const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
-        float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
-        const input_data_t *__restrict firstit_states_,
-        const float *__restrict diff_dst_iter_) const {
-    AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
-    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
-    const bool quantize = pd()->with_src_iter()
-            && pd()->src_md(1)->data_type == data_type::f32
-            && rnn.dt_conf != all_f32;
-    auto maybe_q = [&](input_data_t f) {
-        if (quantize) {
-            float qf = f * data_scale + data_shift;
-            return qz_a1b0<float, src_data_t>()(qf);
-        } else
-            return (src_data_t)f;
-    };
-
-    const bool dequantize = pd()->with_src_iter()
-            && pd()->src_md(1)->data_type == data_type::u8;
-    auto maybe_deq = [&](input_data_t s) {
-        if (dequantize)
-            return (((float)s - data_shift) / data_scale);
-        else
-            return (float)s;
-    };
-    auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
-    if (firstit_states_) {
-        parallel_nd(
-                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
-                    for (int s = 0; s < rnn.sic; s++)
-                        ws_states(lay + 1, dir, 0, b, s) = maybe_q(
-                                firstit_states_[firstit_states_d.blk_off(
-                                        lay, dir, 0, b, s)]);
-                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
-                        for (int s = 0; s < rnn.sic; s++)
-                            ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
-                                    firstit_states_[firstit_states_d.blk_off(
-                                            lay, dir, 1, b, s)]);
-                });
-    } else {
-        parallel_nd(
-                rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
-                    for (int j = 0; j < rnn.sic; j++) {
-                        ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
-                        ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
-                    }
-                });
-    }
-}
-
-template <>
-template <typename input_data_t>
-void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
-        src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
-        const input_data_t *firstit_states_,
-        const float *diff_dst_iter_) const {
-    AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
-    if (diff_dst_iter_) {
-        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
-                [&](int lay, int dir, int state, int b) {
-                    array_copy(&(ws_diff_states(
-                                       lay, dir, state, rnn.n_iter, b, 0)),
-                            diff_dst_iter_
-                                    + diff_dst_iter_d.blk_off(
-                                              lay, dir, state, b),
-                            rnn.dic);
-                });
-    } else {
-        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
-                [&](int lay, int dir, int state, int i) {
-                    for (int j = 0; j < rnn.dic; j++)
-                        ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
-                                = 0.0f;
-                });
-    }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename dst_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
-        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
-        const src_data_t *ws_states_, const float *ws_diff_states_) const {
-
-    auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
-    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    float shift = (pd()->attr()->rnn_data_qparams_.shift_);
-    float scale = (pd()->attr()->rnn_data_qparams_.scale_);
-
-    const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
-            && rnn.dt_conf != all_f32;
-    auto maybe_deq = [&](src_data_t s) {
-        if (dequantize)
-            return (dst_data_t)(((float)s - shift) / scale);
-        else
-            return (dst_data_t)s;
-    };
-    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-        int dir = 0;
-        if (rnn.exec_dir != r2l) {
-            for (int s = 0; s < rnn.dic; s++) {
-                dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
-                        = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
-            }
-            dir = 1;
-        }
-        if (rnn.exec_dir != l2r) {
-            for (int s = 0; s < rnn.dic; s++)
-                switch (rnn.exec_dir) {
-                case bi_sum:
-                    dst_layer_[dst_layer_d.blk_off(it, b, s)]
-                            += maybe_deq(ws_states(
-                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
-                    break;
-                default:
-                    dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
-                            = maybe_deq(ws_states(
-                                    rnn.n_layer, dir, rnn.n_iter - it, b, s));
-                }
-        }
-    });
-}
-
-template <>
-template <typename dst_data_t>
-void ref_rnn_bwd_f32_t::copy_res_layer(
-        const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
-        const src_data_t *ws_states_, const float *ws_diff_states_) const {
-    auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
-    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
-            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
-            rnn.states_ws_ld);
-
-    parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
-        int dir = 0;
-        for (int s = 0; s < rnn.slc; s++) {
-            float *dst_addr = diff_src_layer_
-                    + diff_src_layer_d.blk_off(
-                              (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
-                              b, dir * rnn.slc + s);
-            float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
-            if (rnn.n_dir - 1)
-                res += ws_diff_states(
-                        0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
-            dst_addr[0] = res;
-        }
-    });
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-template <typename output_data_t>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
-        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
-        const src_data_t *ws_states_, float *ws_c_states_,
-        const float *ws_diff_states_) const {
-    auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
-    AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
-            rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
-    float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
-    float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-
-    const bool quantize = pd()->with_dst_iter()
-            && pd()->dst_md(1)->data_type == data_type::u8
-            && rnn.dt_conf != all_f32;
-    auto maybe_q = [&](float f) {
-        if (quantize) {
-            float qf = f * data_scale + data_shift;
-            return qz_a1b0<float, output_data_t>()(qf);
-        } else
-            return (output_data_t)f;
-    };
-
-    const bool dequantize = pd()->with_dst_iter()
-            && pd()->dst_md(1)->data_type == data_type::f32
-            && rnn.dt_conf != all_f32;
-    auto maybe_deq = [&](src_data_t s) {
-        if (dequantize)
-            return (output_data_t)(((float)s - data_shift) / data_scale);
-        else
-            return (output_data_t)s;
-    };
-    if (dst_iter_) {
-        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
-                [&](int lay, int dir, int b) {
-                    for (int s = 0; s < rnn.dic; s++) {
-                        dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
-                                = maybe_deq(ws_states(
-                                        lay + 1, dir, rnn.n_iter, b, s));
-                    }
-                    if (pd()->cell_kind() == alg_kind::vanilla_lstm)
-                        for (int s = 0; s < rnn.dic; s++) {
-                            dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
-                                    = maybe_q(ws_c_states(
-                                            lay + 1, dir, rnn.n_iter, b, s));
-                        }
-                });
-    }
-}
-
-template <>
-template <typename output_data_t>
-void ref_rnn_bwd_f32_t::copy_res_iter(
-        const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
-        const src_data_t *ws_states_, float *ws_c_states_,
-        const float *ws_diff_states_) const {
-    auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
-    AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
-            rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
-            rnn.states_ws_ld);
-    if (diff_src_iter_) {
-        parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
-                [&](int lay, int dir, int state, int b) {
-                    for (int s = 0; s < rnn.sic; s++) {
-                        diff_src_iter_[diff_src_iter_d.blk_off(
-                                lay, dir, state, b, s)]
-                                = ws_diff_states(lay, dir, state, 0, b, s);
-                    }
-                });
-    }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_prepare_sig(
-        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
-    /* Original set of bias provided by the user */
-    AOC<const float, 5> b(
-            b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
-    /* Array of pointers initialized in packing */
-    AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
-    AOC<float, 3> scratch_bias(
-            scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
-
-    if (rnn.copy_bias) {
-        parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
-                [&](size_t i) { scratch_bias_[i] = b_[i]; });
-    }
-
-    for (int i = 0; i < rnn.n_layer; i++) {
-        for (int d = 0; d < rnn.n_dir; d++) {
-            int offset_bias = 0;
-            for (int p = 0; p < rnn.n_parts_bias; p++) {
-                bias(i, d, p) = rnn.copy_bias
-                        ? (float *)&scratch_bias(i, d, offset_bias)
-                        : (float *)&b(i, d, offset_bias);
-                offset_bias += rnn.parts_bias[p] * rnn.dic;
-            }
-        }
-    }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_bias_finalize_sig(
-        (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
-    if (rnn.dt_conf != all_f32) {
-        float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
-        float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
-        float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
-        bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
-        for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
-            for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
-                size_t off = i * rnn.n_bias * rnn.dic + j;
-                float weights_scale
-                        = scale_per_oc ? weights_scales[j] : weights_scales[0];
-                scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
-                        * data_shift / (weights_scale * data_scale);
-            }
-    }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
-        weights_type>::assign_packed_weights)) {
-    assert(md->format_kind == format_kind::rnn_packed);
-    const auto packed_desc = md->format_desc.rnn_packed_desc;
-    AOC<weights_data_t *, 3> weights(
-            weights_, rnn.n_layer, rnn.n_dir, packed_desc.n_parts);
-
-    size_t offset_packed = 0;
-    for (int l = 0; l < rnn.n_layer; l++)
-        for (int d = 0; d < rnn.n_dir; d++) {
-            for (int p = 0; p < packed_desc.n_parts; p++) {
-                weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
-                offset_packed += packed_desc.part_pack_size[p]
-                        / sizeof(weights_data_t);
-            }
-        }
-}
-
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-rnn_weights_assign_sig(
-        (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
-    assert(md->format_kind == format_kind::blocked);
-    const auto &blk = md->format_desc.blocking;
-    /* Original set of weights provided by the user */
-    AOC<const weights_data_t, 3> w(
-            w_, rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
-    /* Array of pointers for each part of weights */
-    AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);
-
-    for (int i = 0; i < rnn.n_layer; i++)
-        for (int d = 0; d < rnn.n_dir; d++) {
-            size_t offset_weights = 0;
-            for (int p = 0; p < n_parts; p++) {
-                weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
-                offset_weights += gates_per_part[p] * blk.strides[3];
-            }
-        }
-}
-
-//********************* Execution function *********************//
-template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
-void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
-        const exec_ctx_t &ctx) const {
-    const rnn_conf_t &rnn = this->pd()->rnn_;
-    auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
-    auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
-    auto layer_weights_n_comp
-            = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
-    auto iter_weights_n_comp
-            = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
-    auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
-
-    auto dst_last_layer = rnn.is_fwd
-            ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
-            : const_cast<char *>(
-                      CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
-    auto dst_last_iter = rnn.is_fwd
-            ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
-            : const_cast<char *>(
-                      CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));
-
-    auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
-    auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);
-
-    auto w_layer
-            = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
-    auto w_iter
-            = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
-    auto w_iter_comp = reinterpret_cast<const float *>(
-            iter_weights_n_comp + rnn.weights_iter_comp_offset);
-    auto w_layer_comp = reinterpret_cast<const float *>(
-            layer_weights_n_comp + rnn.weights_layer_comp_offset);
-
-    auto scratchpad = this->scratchpad(ctx);
-
-    auto ptr_wei_layer
-            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
-    auto ptr_wei_iter
-            = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
-    auto ptr_bias = scratchpad.template get<float *>(key_rnn_ptrs_bia);
-
-    // fetching buffers from the workspace;
-    // if no workspace was provided we use the scratchpad
-    char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
-    char *ws_ptr = nullptr;
-    if (rnn.use_workspace)
-        ws_ptr = rnn.is_fwd
-                ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
-                : const_cast<char *>(
-                          CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
-
-    char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
-    acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
-    src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
-    float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
-    float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
-    float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
-    float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);
-
-    auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
-    auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);
-
-    auto diff_weights_layer
-            = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
-    auto diff_weights_iter
-            = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
-    auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
-
-    // Fetching extra buffers from scratchpad
-    float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);
-
-    // initialize diff_states to 0
-    if (aprop == prop_kind::backward)
-        array_set(ws_diff_states, 0.0f,
-                rnn.ws_diff_states_size / sizeof(float));
-
-    /* Pack (if using the packed gemm API) or copy (if input arrays have a
-     * bad leading dimension) */
-    (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
-
-    (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
-            rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
-            rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
-            rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
-            ptr_bias, bias, ws_bias);
-    (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
-            rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
-            rnn.n_parts_weights_layer, rnn.parts_weights_layer,
-            rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
-            bias, ws_bias);
-
-    (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
-
-    // we first need to copy the initial states and input into ws
-    copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
-    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
-            || rnn.dt_conf == all_f32)
-        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
-                (const float *)states, diff_dst_iter);
-    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
-        copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
-                (const uint8_t *)states, diff_dst_iter);
-    else
-        assert(!"unimplemented");
-
-    // run the execution on the grid
-    (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
-            ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
-            diff_weights_layer, diff_weights_iter, diff_bias);
-
-    // Finally we copy the results to the result buffers
-    if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
-            || rnn.dt_conf == all_f32)
-        copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
-                ws_diff_states);
-    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
-        copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
-                ws_states, ws_diff_states);
-    else
-        assert(!"unimplemented");
-
-    if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
-            || rnn.dt_conf == all_f32)
-        copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
-                ws_c_states, ws_diff_states);
-    else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
-        copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
-                ws_c_states, ws_diff_states);
-    else
-        assert(!"unimplemented");
-};
-
-/* Fix for MSVS warning C4661 */
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
-template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
-template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
-
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32,
-        data_type::f32>;
-template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8,
-        data_type::s8>;
-template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32,
-        data_type::f32>;
-
-#undef AOC
-}
-}
-}
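
The "General architecture" comment at the top of the deleted file packs n_states + 1 gradient slots per (layer, iteration) cell: n_states going to the previous iteration and one going to the previous layer. A minimal standalone sketch of that indexing convention follows (type and field names are illustrative, not mkl-dnn API):

#include <cstdio>
#include <vector>

// Standalone sketch of the diff-states layout described in ref_rnn.cpp.
struct DiffStates {
    // For an LSTM, n_states == 2 (h and c); slot n_states holds the
    // gradient coming from the layer above.
    int n_states, n_iter, mb_x_ld; // mb_x_ld = mb * states_ws_ld, flattened
    std::vector<float> buf;        // [n_states + 1][n_iter + 1][mb_x_ld]

    DiffStates(int n_states, int n_iter, int mb_x_ld)
        : n_states(n_states), n_iter(n_iter), mb_x_ld(mb_x_ld),
          buf((n_states + 1) * (n_iter + 1) * mb_x_ld, 0.f) {}

    float *at(int state, int iter) {
        return &buf[(state * (n_iter + 1) + iter) * mb_x_ld];
    }
};

int main() {
    DiffStates ds(/*n_states=*/2, /*n_iter=*/4, /*mb_x_ld=*/8);
    int iter = 2;
    // index 0: dh for cell(t-1, l) to consume
    float *dh_prev_iter = ds.at(0, iter);
    // index 1: dc for cell(t-1, l) to consume
    float *dc_prev_iter = ds.at(1, iter);
    // index n_states (== 2 here): dh for cell(t, l-1) to consume
    float *dh_prev_layer = ds.at(ds.n_states, iter);
    std::printf("offsets: %td %td %td\n", dh_prev_iter - ds.buf.data(),
            dc_prev_iter - ds.buf.data(), dh_prev_layer - ds.buf.data());
    return 0;
}

With this layout an elemwise kernel addresses dh and dc through the same state index whether the gradient flows across time or across layers, which is the property the comment claims; only the cell execution function has to know which slot feeds which neighbor.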
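copy_init_iter and copy_res_iter convert between f32 and u8 states through the maybe_q/maybe_deq lambdas built on qz_a1b0 from simple_q10n.hpp. Below is a self-contained approximation of that round trip, assuming qz_a1b0<float, uint8_t> rounds to nearest and saturates to [0, 255]; the helper names are mine, not mkl-dnn's:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Approximation of the maybe_q / maybe_deq lambdas in ref_rnn.cpp.
static uint8_t quantize_u8(float f, float scale, float shift) {
    float qf = f * scale + shift;            // affine quantization
    qf = std::nearbyintf(qf);                // round (to nearest even by
                                             // default rounding mode)
    qf = std::min(std::max(qf, 0.f), 255.f); // saturate to the u8 range
    return (uint8_t)qf;
}

static float dequantize_u8(uint8_t q, float scale, float shift) {
    return ((float)q - shift) / scale;       // inverse affine transform
}

int main() {
    float scale = 64.f, shift = 128.f;       // example qparams
    float h = 0.73f;                         // an f32 hidden state
    uint8_t q = quantize_u8(h, scale, shift);
    std::printf("h=%f -> q=%u -> h'=%f\n", h, (unsigned)q,
            dequantize_u8(q, scale, shift));
    return 0;
}

This mirrors the asymmetry the file's comment describes: h states live in u8 inside the workspace (quantize on the way in, dequantize on the way out), while LSTM c states stay in f32.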
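bias_finalize subtracts (w_iter_comp[off] + w_layer_comp[off]) * data_shift / (weights_scale * data_scale) from the f32 bias. The reason: with x_q = x * data_scale + data_shift and w_q = w * weights_scale, an integer dot product accumulates an extra data_shift * sum(w_q) term that must be removed after rescaling. A toy numeric check of that identity (values and helper names invented for illustration):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    float data_scale = 2.f, data_shift = 100.f, weights_scale = 4.f;
    std::vector<float> w = {0.5f, -0.25f, 1.f};
    std::vector<float> x = {1.f, 2.f, 3.f};

    float exact = 0.f; // reference f32 dot product
    int32_t acc = 0;   // what the u8*s8 gemm accumulates
    float w_comp = 0.f; // compensation: sum of quantized weights
    for (size_t k = 0; k < w.size(); k++) {
        exact += w[k] * x[k];
        int8_t wq = (int8_t)(w[k] * weights_scale);
        uint8_t xq = (uint8_t)(x[k] * data_scale + data_shift);
        acc += (int32_t)wq * xq;
        w_comp += wq;
    }
    // Undo the scales, then remove the shift-induced term, as bias_finalize
    // folds into the bias ahead of time.
    float recovered = acc / (weights_scale * data_scale)
            - w_comp * data_shift / (weights_scale * data_scale);
    std::printf("exact=%f recovered=%f\n", exact, recovered); // both 3.0
    return 0;
}

Here w_comp plays the role of the per-output compensation terms (w_iter_comp, w_layer_comp) that mkl-dnn precomputes and stores alongside the quantized weights, so the correction costs one subtraction per bias element rather than a reduction per gemm call.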