summaryrefslogtreecommitdiffstats
path: root/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp145
1 files changed, 145 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
new file mode 100644
index 0000000000..565e97e4fc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
@@ -0,0 +1,145 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_TRANSPOSE_SRC_HPP
+#define CPU_JIT_TRANSPOSE_SRC_HPP
+
+#include "cpu_barrier.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_trans_src_t {
+ struct ctx_t {
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+
+ /* 1st conv 4fma: backward by weights */
+ int nthr_oc_b; /* number of threads process given src image */
+ int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
+ simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
+ };
+
+ jit_trans_src_t(const jit_conv_conf_t *conf)
+ : conf_(conf), ker_(nullptr) {}
+ virtual ~jit_trans_src_t() {}
+
+ void operator()(const ctx_t *ctx)
+ { assert(ker_); ker_(ctx); }
+
+ const jit_conv_conf_t *conf_;
+ void (*ker_)(const ctx_t *);
+};
+
+struct jit_src_transpose_s {
+ size_t size;
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+};
+
+struct jit_trans_dst_t {
+ struct ctx_t {
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+
+ /* 1st conv 4fma: backward by weights */
+ int nthr_oc_b; /* number of threads process given src image */
+ int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
+ simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
+ };
+
+ jit_trans_dst_t(const jit_conv_conf_t *conf)
+ : conf_(conf), ker_(nullptr) {}
+ virtual ~jit_trans_dst_t() {}
+
+ void operator()(const ctx_t *ctx)
+ { assert(ker_); ker_(ctx); }
+
+ const jit_conv_conf_t *conf_;
+ void (*ker_)(const ctx_t *);
+};
+
+struct jit_transpose4x16_src_t {
+ int src_pf0_distance;
+ int tr_src_pf0_distance;
+ bool src_pf1;
+ bool tr_src_pf1;
+};
+
+struct jit_transpose4x16_src : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src)
+
+ jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams,
+ jit_transpose4x16_src_t *tparams_)
+ : params(aparams), tparams(tparams_)
+ {
+ this->generate();
+ jit_ker = (decltype(jit_ker))this->getCode();
+ }
+
+ const jit_1x1_conv_conf_t *params;
+ const jit_transpose4x16_src_t *tparams;
+ void (*jit_ker)(jit_src_transpose_s *);
+
+ void operator()(jit_src_transpose_s *arg) { jit_ker(arg); }
+
+ static const int transpose_size = 4;
+private:
+ static const int typesize = sizeof(float);
+
+ int src_stride, tr_src_stride;
+
+ Xbyak::Reg64 imm_addr64 = rbx;
+
+ Xbyak::Opmask kF0 = k1;
+ Xbyak::Opmask kCC = k2;
+ Xbyak::Opmask k33 = k3;
+ Xbyak::Opmask kFFFF = k4;
+
+ Xbyak::Zmm vidx01 = zmm31;
+ Xbyak::Zmm vidx10 = zmm30;
+ Xbyak::Zmm vidx1 = zmm29;
+ Xbyak::Zmm vidxP = zmm28;
+
+ Xbyak::Reg64 reg_src = r8;
+ Xbyak::Reg64 reg_tr_src = r9;
+ Xbyak::Reg64 reg_src_prf = r10;
+ Xbyak::Reg64 reg_tr_src_prf = r11;
+ Xbyak::Reg64 reg_loop = r12;
+ Xbyak::Reg64 reg_tr_src_tmp = r13;
+ Xbyak::Reg32 regw_tmp = r14d;
+
+ void transpose_block(int ur, int nrows);
+ void transpose(int nrows);
+ void generate();
+};
+
+jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
+jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
+
+}
+}
+}
+
+#endif