Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp | 313
1 file changed, 313 insertions(+), 0 deletions(-)
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
new file mode 100644
index 0000000000..69b7a33604
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
@@ -0,0 +1,313 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "mkldnn_debug.h"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_reorder.hpp"
+
+using namespace mkldnn::impl::types;
+using namespace mkldnn::impl::status;
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace tr {
+
+/** ad-hoc structure to describe blocked memory layout */
+struct layout_desc_t {
+ data_type_t dt;
+ int ndims;
+ dims_t id;
+ dims_t dims;
+ strides_t strides;
+};
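+
+/* Illustrative sketch (hypothetical values, not from a real call): a 4x8
+ * row-major f32 matrix would be described as
+ *   ld.dt = f32; ld.ndims = 2;
+ *   ld.id = {0, 1}; ld.dims = {4, 8}; ld.strides = {8, 1};
+ * i.e. one (id, dim, stride) triple per physical dimension; blocked formats
+ * contribute several triples that share one logical id. */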
+
+status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
+ layout_desc_t &ld) {
+ const auto md = memory_desc_wrapper(md_);
+
+ bool ok = true
+ && md.is_blocking_desc()
+ && md.extra().flags == 0;
+ if (!ok) return invalid_arguments;
+
+ const auto &bd = md.blocking_desc();
+
+ ld.ndims = 0;
+ ld.dt = md.data_type();
+
+ auto P = [&ld](int id, int dim, ptrdiff_t stride) {
+ assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
+ ld.id[ld.ndims] = id;
+ ld.dims[ld.ndims] = dim;
+ ld.strides[ld.ndims] = stride;
+ ++ld.ndims;
+ };
+
+ dims_t blocks;
+ md.compute_blocks(blocks);
+
+ for (int d = 0; d < md.ndims(); ++d) {
+ const int ld_ndims_start = ld.ndims;
+ if (blocks[d] != 1) {
+ stride_t stride = 1;
+ for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
+ if (bd.inner_idxs[iblk] == d)
+ P(d, bd.inner_blks[iblk], stride);
+ stride *= bd.inner_blks[iblk];
+ }
+ }
+ P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);
+
+ // TODO: NOW: revisit, do we need a reverse?
+ // TODO: NOW: consider using strides instead of block sizes in md
+ // reverse the order of dims
+ for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
+ const int idx0 = ld_ndims_start + ld_d;
+ const int idx1 = ld.ndims - 1 - ld_d;
+ nstl::swap(ld.dims[idx0], ld.dims[idx1]);
+ nstl::swap(ld.strides[idx0], ld.strides[idx1]);
+ }
+ }
+
+ return success;
+}
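+
+/* Worked example (hypothetical shape, assuming the nChw8c format): for a
+ * 2x16x5x5 tensor the converter above yields five (id, dim, stride) triples:
+ *   id:     0    1    1   2   3
+ *   dim:    2    2    8   5   5
+ *   stride: 400  200  1   40  8
+ * The channel dim appears twice -- once as the outer 16/8 = 2 chunks and once
+ * as the inner block of 8 -- with strides descending within the logical dim. */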
+
+status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
+ const primitive_attr_t *attr) {
+ auto im_d = memory_desc_wrapper(imd);
+ auto om_d = memory_desc_wrapper(omd);
+
+ bool ok = true
+ && im_d.is_blocking_desc()
+ && om_d.is_blocking_desc()
+ && !im_d.has_zero_dim()
+ && !om_d.has_zero_dim();
+ if (!ok)
+ return unimplemented;
+
+ dims_t iblocks, oblocks;
+ im_d.compute_blocks(iblocks);
+ om_d.compute_blocks(oblocks);
+
+ /* padding_dim consistency check */
+ for (int d = 0; d < im_d.ndims(); ++d) {
+ const auto pdim = im_d.padded_dims()[d];
+ bool ok = true
+ && pdim == om_d.padded_dims()[d]
+ && pdim % iblocks[d] == 0
+ && pdim % oblocks[d] == 0;
+ if (!ok) return unimplemented;
+ }
+
+ layout_desc_t ild, old;
+ status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
+ if (status != success) return status;
+ status = cvt_mem_desc_to_layout_desc(omd, old);
+ if (status != success) return status;
+
+ p.itype = ild.dt;
+ p.otype = old.dt;
+
+ p.scale_type = attr->output_scales_.has_default_values()
+ ? scale_type_t::NONE
+ : (attr->output_scales_.mask_ == 0
+ ? scale_type_t::COMMON
+ : scale_type_t::MANY);
+
+ ptrdiff_t ss[max_ndims] = {0};
+ if (p.scale_type == scale_type_t::MANY) {
+ ptrdiff_t last_ss = 1;
+ for (int d = old.ndims - 1; d >= 0; --d) {
+ assert((d == 0 || old.id[d - 1] <= old.id[d])
+ && "logical dimensions should be in ascending order");
+ if (attr->output_scales_.mask_ & (1 << old.id[d])) {
+ ss[d] = last_ss;
+ last_ss *= old.dims[d];
+ }
+ }
+ }
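+
+ /* e.g. (hypothetical) mask_ == 1 << 1 requests one scale per output
+ * channel; ss[d] then carries the stride into the scales array for every
+ * node whose logical id is the channel dim, and stays 0 for all others. */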
+
+ int ndims = 0;
+
+ int i_pos = 0; /* state for input -- current dimension */
+ int o_pos = 0; /* state for output -- current dimension */
+
+ while (i_pos < ild.ndims && o_pos < old.ndims) {
+ assert(ild.id[i_pos] == old.id[o_pos]);
+ if (ild.id[i_pos] != old.id[o_pos])
+ return runtime_error;
+
+ assert(ndims < max_ndims);
+ if (ndims == max_ndims)
+ return runtime_error;
+
+ if (ild.dims[i_pos] == old.dims[o_pos]) {
+ p.nodes[ndims].n = ild.dims[i_pos];
+ p.nodes[ndims].is = ild.strides[i_pos];
+ p.nodes[ndims].os = old.strides[o_pos];
+ p.nodes[ndims].ss = ss[o_pos];
+ ++ndims;
+ ++i_pos;
+ ++o_pos;
+ } else if (ild.dims[i_pos] < old.dims[o_pos]) {
+ assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
+ int factor = old.dims[o_pos] / ild.dims[i_pos];
+ p.nodes[ndims].n = ild.dims[i_pos];
+ p.nodes[ndims].is = ild.strides[i_pos];
+ p.nodes[ndims].os = old.strides[o_pos] * factor;
+ p.nodes[ndims].ss = ss[o_pos] * factor;
+ ++ndims;
+ ++i_pos;
+ old.dims[o_pos] = factor;
+ } else if (ild.dims[i_pos] > old.dims[o_pos]) {
+ assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
+ int factor = ild.dims[i_pos] / old.dims[o_pos];
+ p.nodes[ndims].n = old.dims[o_pos];
+ p.nodes[ndims].is = ild.strides[i_pos] * factor;
+ p.nodes[ndims].os = old.strides[o_pos];
+ p.nodes[ndims].ss = ss[o_pos];
+ ++ndims;
+ ++o_pos;
+ ild.dims[i_pos] = factor;
+ }
+ }
+ p.ndims = ndims;
+
+ dims_t zero_pos = {0};
+ p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
+ p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
+
+ const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
+ p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
+
+ return success;
+}
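+
+/* Sketch of the matching loop above on hypothetical descriptors: reordering
+ * plain nchw 2x16x5x5 into nChw8c gives (dim:stride per entry)
+ *   ild: [2:400][16:25][5:5][5:1]
+ *   old: [2:400][2:200][8:1][5:40][5:8]
+ * The 16-channel input entry is wider than the output's 2-chunk entry, so it
+ * is split by factor 8, producing five nodes: n = 2, 2 (chunks), 8 (block),
+ * 5, 5. */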
+
+void prb_normalize(prb_t &p) {
+ for (int d = 0; d < p.ndims; ++d) {
+ int min_pos = d;
+ for (int j = d + 1; j < p.ndims; ++j) {
+ bool new_min = false
+ || p.nodes[j].os < p.nodes[min_pos].os
+ || (true
+ && p.nodes[j].os == p.nodes[min_pos].os
+ && p.nodes[j].n < p.nodes[min_pos].n);
+ if (new_min) min_pos = j;
+ }
+ if (min_pos != d)
+ nstl::swap(p.nodes[d], p.nodes[min_pos]);
+ }
+}
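+
+/* Sketch (hypothetical): nodes with output strides {40, 1, 200} come out as
+ * {1, 40, 200} -- a selection sort on ascending os, breaking ties by smaller
+ * n, so the innermost output dimension always ends up at index 0. */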
+
+void prb_simplify(prb_t &p) {
+#if defined(__GNUC__) && __GNUC__ >= 4
+/* GCC produces a bogus "array subscript is above array bounds" warning for
+ * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+#endif
+ for (int d = 0; d < p.ndims - 1; ++d) {
+ auto &this_node = p.nodes[d + 0];
+ auto &next_node = p.nodes[d + 1];
+ const bool fold = false
+ || next_node.n == (size_t)1 // trivial case, just drop next node
+ || (true // or real folding if possible
+ && next_node.is == (ptrdiff_t)this_node.n * this_node.is
+ && next_node.os == (ptrdiff_t)this_node.n * this_node.os
+ && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
+ if (fold) {
+ this_node.n *= next_node.n;
+ for (int j = d + 2; j < p.ndims; ++j)
+ p.nodes[j - 1] = p.nodes[j];
+ --p.ndims;
+ --d; // make another try
+ }
+ }
+#if defined(__GNUC__) && __GNUC__ >= 4
+#pragma GCC diagnostic pop
+#endif
+}
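+
+/* Folding sketch (hypothetical nodes): [n=2:is=8:os=8:ss=0] followed by
+ * [n=4:is=16:os=16:ss=0] satisfies is/os/ss(next) == n * is/os/ss(this), so
+ * the pair collapses into a single dense node [n=8:is=8:os=8:ss=0]. */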
+
+void prb_node_split(prb_t &p, int dim, size_t n1) {
+ assert(dim < p.ndims);
+ assert(p.ndims < max_ndims);
+ assert(p.nodes[dim].n % n1 == 0);
+
+ p.ndims += 1;
+
+ /* shift the tail one slot right; starting at p.ndims - 1 keeps the
+ * write inside nodes[] even when the new ndims equals max_ndims */
+ for (int d = p.ndims - 1; d > dim + 1; --d)
+ p.nodes[d] = p.nodes[d - 1];
+
+ p.nodes[dim + 1].n = p.nodes[dim].n / n1;
+ p.nodes[dim + 1].is = p.nodes[dim].is * n1;
+ p.nodes[dim + 1].os = p.nodes[dim].os * n1;
+ p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
+
+ p.nodes[dim].n = n1;
+}
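+
+/* Usage sketch (hypothetical): prb_node_split(p, 0, 8) on a node
+ * [n=16:is=1:os=1] leaves [n=8:is=1:os=1] at index 0 and inserts
+ * [n=2:is=8:os=8] after it -- the inverse of the folding in prb_simplify. */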
+
+void prb_node_swap(prb_t &p, int d0, int d1) {
+ assert(d0 < p.ndims);
+ assert(d1 < p.ndims);
+ assert(p.ndims < max_ndims);
+
+ if (d0 == d1) return;
+
+ nstl::swap(p.nodes[d0], p.nodes[d1]);
+}
+
+void prb_node_move(prb_t &p, int d0, int d1) {
+ assert(d0 < p.ndims);
+ assert(d1 < p.ndims);
+ assert(p.ndims < max_ndims);
+
+ if (d0 == d1) return;
+
+ node_t node = p.nodes[d0];
+
+ if (d0 < d1)
+ for (int d = d0; d < d1; ++d)
+ p.nodes[d] = p.nodes[d + 1];
+ else
+ for (int d = d0; d > d1; --d)
+ p.nodes[d] = p.nodes[d - 1];
+
+ p.nodes[d1] = node;
+}
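+
+/* Note (illustration): unlike prb_node_swap, this is a rotation -- moving
+ * d0 = 2 to d1 = 0 saves node 2, shifts nodes 0 and 1 up one slot, and
+ * drops the saved node at index 0, preserving the relative order of the
+ * nodes in between. */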
+
+void prb_dump(const prb_t &p) {
+ printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
+ mkldnn_dt2str(p.otype), p.ndims);
+ for (int d = 0; d < p.ndims; ++d)
+ printf("[%zu:%td:%td:%td]",
+ p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
+ printf(" off:%zu:%zu\n", p.ioff, p.ooff);
+}
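+
+/* Example output (hypothetical two-node f32 -> f32 problem):
+ * @@@ type:f32:f32 ndims:2 [8:1:1:0][2:8:8:0] off:0:0 */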
+
+} // namespace tr
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn