diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp | 155 |
1 files changed, 0 insertions, 155 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp deleted file mode 100644 index 057cc3c4c7..0000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp +++ /dev/null @@ -1,155 +0,0 @@ -/******************************************************************************* -* Copyright 2017-2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef SIMPLE_CONCAT_HPP -#define SIMPLE_CONCAT_HPP - -#include "memory_tracking.hpp" - -#include "cpu_concat_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -template <data_type_t data_type> -struct simple_concat_t: public cpu_primitive_t { - struct pd_t: public cpu_concat_pd_t { - using cpu_concat_pd_t::cpu_concat_pd_t; - - pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) { - int ndims = rhs.dst_md_.ndims; - utils::array_copy(perm_, rhs.perm_, ndims); - utils::array_copy(iperm_, rhs.iperm_, ndims); - utils::array_copy(blocks_, rhs.blocks_, ndims); - } - - DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); - - status_t init() { - const memory_desc_wrapper dst_d(dst_md()); - bool ok = true - && cpu_concat_pd_t::init() == status::success - && dst_d.ndims() <= 6; - if (!ok) return status::unimplemented; - - for (size_t i = 0; i < src_mds_.size(); ++i) { - const memory_desc_wrapper i_d(&src_mds_[i]); - const memory_desc_wrapper o_d(&src_image_mds_[i]); - - const int ignore_strides = 0; - - ok = ok - && utils::everyone_is(data_type, i_d.data_type(), - o_d.data_type()) - && utils::everyone_is(format_kind::blocked, - i_d.format_kind(), o_d.format_kind()) - && types::blocking_desc_is_equal(i_d.blocking_desc(), - o_d.blocking_desc(), ignore_strides) - && types::blocking_desc_is_equal(i_d.blocking_desc(), - dst_d.blocking_desc(), ignore_strides) - && !i_d.is_additional_buffer(); - if (!ok) return status::unimplemented; - } - - dst_d.compute_blocks(blocks_); - format_perm(); - - // start dim is the first dimension after which the concatenation - // would happen contiguously - const int start_dim = perm_[concat_dim()]; - - // check that contiguous part is indeed contiguous (i.e. dense) - if (nelems_to_concat(dst_d) != - dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] - * dst_d.blocking_desc().strides[concat_dim()]) - return status::unimplemented; - - // check that all inputs have the same strides for the - // contiguous part [concat_dim .. ndims] for the *major* dims. - // the block part is already checked above - for (size_t i = 0; i < src_mds_.size(); ++i) { - const memory_desc_wrapper i_d(&src_mds_[i]); - for (int d = start_dim; d < dst_d.ndims(); ++d) { - if (dst_d.blocking_desc().strides[iperm_[d]] - != i_d.blocking_desc().strides[iperm_[d]]) - return status::unimplemented; - } - } - - init_scratchpad(); - - return status::success; - } - - int perm_[MKLDNN_MAX_NDIMS] {}; - int iperm_[MKLDNN_MAX_NDIMS] {}; - dims_t blocks_ {}; - - dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { - const int ndims = data_d.ndims(); - - dim_t nelems = 1; - for (int i = perm_[concat_dim()]; i < ndims; i++) - nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; - for (int i = 0; i < ndims; i++) - nelems *= blocks_[i]; - - return nelems; - } - - private: - void format_perm() { - const memory_desc_wrapper dst_d(dst_md()); - const int ndims = dst_d.ndims(); - - strides_t strides; - utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); - for (int i = 0; i < ndims; i++) iperm_[i] = i; - - utils::simultaneous_sort(strides, iperm_, ndims, - [](stride_t a, stride_t b) { return b - a; }); - - for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i; - } - - void init_scratchpad() { - using namespace memory_tracking::names; - auto scratchpad = scratchpad_registry().registrar(); - scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs()); - scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs()); - scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs()); - scratchpad.book(key_concat_istrides, - sizeof(strides_t) * n_inputs()); - } - }; - - simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {} - - virtual status_t execute(const exec_ctx_t &ctx) const override; - - typedef typename prec_traits<data_type>::type data_t; - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } -}; - -} -} -} - -#endif |