diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp | 502 |
1 files changed, 0 insertions, 502 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp deleted file mode 100644 index d61903c32d..0000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp +++ /dev/null @@ -1,502 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef CPU_REF_DECONVOLUTION_HPP -#define CPU_REF_DECONVOLUTION_HPP - -#include <assert.h> -#include <string.h> - -#include "c_types_map.hpp" -#include "type_helpers.hpp" -#include "utils.hpp" -#include "primitive_iterator.hpp" - -#include "cpu_convolution_pd.hpp" -#include "cpu_deconvolution_pd.hpp" -#include "cpu_primitive.hpp" - -namespace mkldnn { -namespace impl { -namespace cpu { - -static status_t compute_blocked_format(bool with_groups, - const memory_desc_t *oi_md, memory_desc_t *io_md) -{ - /* Computes blocking for *i*o* format from *o*i* format */ - - bool sanity_check_ok = true - && oi_md->ndims == io_md->ndims - && oi_md->format_kind == format_kind::blocked; - if (!sanity_check_ok) return status::invalid_arguments; - - const blocking_desc_t &oi_blk = oi_md->format_desc.blocking; - blocking_desc_t io_blk = io_md->format_desc.blocking; - - io_md->format_kind = format_kind::blocked; - io_blk = oi_blk; - - const int ID_OC = 0 + with_groups; - const int ID_IC = 1 + with_groups; - - nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]); - for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) { - if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) { - io_blk.inner_idxs[i_blk] = - (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC); - } - } - - return memory_desc_init_by_blocking_desc(*io_md, io_blk); -} - -static status_t conv_descr_create(const deconvolution_desc_t *dd, - convolution_desc_t *cd) -{ - using namespace prop_kind; - alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct - ? alg_kind::convolution_direct : alg_kind::convolution_winograd; - - const memory_desc_t *src_md, *dst_md, *d_weights_d; - prop_kind_t prop_kind; - memory_desc_t c_weights_d; - if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) { - prop_kind = backward_data; - src_md = &dd->dst_desc; - dst_md = &dd->src_desc; - d_weights_d = &dd->weights_desc; - } else if (dd->prop_kind == backward_data) { - prop_kind = forward_training; - src_md = &dd->diff_dst_desc; - dst_md = &dd->diff_src_desc; - d_weights_d = &dd->weights_desc; - } else { - prop_kind = dd->prop_kind; - src_md = &dd->diff_dst_desc; - dst_md = &dd->src_desc; - d_weights_d = &dd->diff_weights_desc; - } - - const bool with_groups = d_weights_d->ndims == src_md->ndims + 1; - - /* create weights desc for convolution */ - c_weights_d = *d_weights_d; - - const int ID_OC = 0 + with_groups; - const int ID_IC = 1 + with_groups; - - nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]); - nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]); - nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]); - - if (c_weights_d.format_kind != format_kind::any) - CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d)); - - return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d, - prop_kind != backward_weights ? &dd->bias_desc : nullptr, - dst_md, dd->strides, dd->dilates, - dd->padding[0], dd->padding[1], dd->padding_kind); -} - -struct ref_deconvolution_fwd_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_fwd_pd_t { - pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_fwd_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) - , conv_supports_bias_(other.conv_supports_bias_) - , dst_tag_(other.dst_tag_) - {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - CHECK(conv_descr_create(desc(), &cd)); - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - conv_supports_bias_ = - static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_) - ->support_bias(); - bool output_f32 = utils::everyone_is(data_type::f32, - desc()->accum_data_type, desc()->dst_desc.data_type); - - bool ok = true - && conv_pd_->weights_md()->extra.flags == 0 - /* deconv reference code can process only f32 bias */ - && IMPLICATION(with_bias(), - conv_supports_bias_ || output_f32); - if (ok) return status::success; - - delete conv_pd_; - } - conv_pd_ = nullptr; - return status::unimplemented; - } - - status_t init() { - using namespace format_tag; - bool ok = true - && is_fwd() - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd) - && attr()->post_ops_.has_default_values(); - - if (ok) { - CHECK(init_convolution()); - if (weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->weights_md(), &desc_.weights_desc)); - weights_md_ = desc_.weights_desc; - } - if (src_md_.format_kind == format_kind::any) - src_md_ = *conv_pd_->diff_dst_md(); - if (dst_md_.format_kind == format_kind::any) - dst_md_ = *conv_pd_->diff_src_md(); - if (bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(bias_md_, x)); - - dst_tag_ = memory_desc_matches_one_of_tag(dst_md_, - utils::pick(ndims() - 3, ncw, nchw, ncdhw), - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), - utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - bool conv_supports_bias_; - format_tag_t dst_tag_; - }; - - typedef typename prec_traits<data_type::f32>::type data_t; - - ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_fwd_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); - conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); - if (pd()->with_bias() && pd()->conv_supports_bias_) - conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS); - conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - conv_p_->execute(conv_ctx); - - if (pd()->with_bias() && !pd()->conv_supports_bias_) { - using namespace format_tag; - - auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS); - auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST); - - switch (pd()->dst_tag_) { - case ncdhw: case nchw: case ncw: - compute_fwd_bias_ncdhw(bias, dst); - break; - case nCdhw8c: case nChw8c: case nCw8c: - compute_fwd_bias_nCdhwXc<8>(bias, dst); - break; - case nCdhw16c: case nChw16c: case nCw16c: - compute_fwd_bias_nCdhwXc<16>(bias, dst); - break; - default: - compute_fwd_bias(bias, dst); - break; - } - } - return status::success; - } - -private: - void compute_fwd_bias(const data_t *bias, data_t *dst) const; - void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const; - template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias, - data_t *dst) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - primitive_t *conv_p_; -}; - -struct ref_deconvolution_bwd_data_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_bwd_data_pd_t { - pd_t(engine_t *engine, const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_bwd_data_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - status_t status = conv_descr_create(desc(), &cd); - if (status != status::success) return status; - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - if (conv_pd_->weights_md()->extra.flags == 0) - return status::success; - delete conv_pd_; - } - - return status::unimplemented; - } - - status_t init() { - using namespace data_type; - bool ok = true - && desc()->prop_kind == prop_kind::backward_data - && utils::everyone_is(data_type::f32, - desc()->diff_src_desc.data_type, - desc()->weights_desc.data_type, - desc()->diff_dst_desc.data_type) - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd); - - if (ok) { - CHECK(init_convolution()); - if (weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->weights_md(), &desc_.weights_desc)); - weights_md_ = desc_.weights_desc; - } - if (diff_src_md_.format_kind == format_kind::any) - diff_src_md_ = *conv_pd_->dst_md(); - if (diff_dst_md_.format_kind == format_kind::any) - diff_dst_md_ = *conv_pd_->src_md(); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - }; - - typedef typename prec_traits<data_type::f32>::type data_t; - - ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_bwd_data_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); - conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS); - conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - conv_p_->execute(conv_ctx); - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - primitive_t *conv_p_; -}; - -struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t { - struct pd_t: public cpu_deconvolution_bwd_weights_pd_t { - pd_t(engine_t *engine, - const deconvolution_desc_t *adesc, - const primitive_attr_t *attr, - const deconvolution_fwd_pd_t *hint_fwd_pd) - : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd) - , conv_pd_(nullptr) - {} - - pd_t(const pd_t &other) - : cpu_deconvolution_bwd_weights_pd_t(other) - , conv_pd_(other.conv_pd_->clone()) - , dst_tag_(other.dst_tag_) - {} - - ~pd_t() { delete conv_pd_; } - - DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t); - - status_t init_convolution() { - using namespace types; - - convolution_desc_t cd; - status_t status = conv_descr_create(desc(), &cd); - if (status != status::success) return status; - - mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd, - &attr_, nullptr); - while (++it != it.end()) { - conv_pd_ = *it; - if (conv_pd_->diff_weights_md()->extra.flags == 0) - return status::success; - delete conv_pd_; - } - return status::unimplemented; - } - - status_t init() { - using namespace format_tag; - bool ok = true - && desc()->prop_kind == prop_kind::backward_weights - && utils::everyone_is(data_type::f32, - desc()->src_desc.data_type, - desc()->diff_weights_desc.data_type, - desc()->diff_dst_desc.data_type) - && utils::one_of(desc()->alg_kind, - alg_kind::deconvolution_direct, - alg_kind::deconvolution_winograd) - && attr()->has_default_values(); - if (ok) { - CHECK(init_convolution()); - if (diff_weights_md_.format_kind == format_kind::any) { - CHECK(compute_blocked_format(with_groups(), - conv_pd_->diff_weights_md(), - &desc_.diff_weights_desc)); - diff_weights_md_ = desc_.diff_weights_desc; - } - if (src_md_.format_kind == format_kind::any) - src_md_ = *conv_pd_->diff_dst_md(); - if (diff_dst_md_.format_kind == format_kind::any) - diff_dst_md_ = *conv_pd_->src_md(); - if (diff_bias_md_.format_kind == format_kind::any) - CHECK(memory_desc_init_by_tag(diff_bias_md_, x)); - - dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_, - utils::pick(ndims() - 3, ncw, nchw, ncdhw), - utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c), - utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c)); - - return status::success; - } - - return status::unimplemented; - } - - virtual void init_scratchpad_md() override { - scratchpad_md_ = *conv_pd_->scratchpad_md(); - } - - primitive_desc_t *conv_pd_; - format_tag_t dst_tag_; - }; - - typedef typename prec_traits<data_type::f32>::type data_t; - - ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) - { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); } - ~ref_deconvolution_bwd_weights_t() { delete conv_p_; } - - virtual status_t execute(const exec_ctx_t &ctx) const override { - const auto &args = ctx.args(); - exec_args_t conv_args; - conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC); - conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST); - conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS); - if (!types::is_zero_md(pd()->scratchpad_md())) - conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD); - const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args)); - - status_t status = conv_p_->execute(conv_ctx); - if (status != status::success) return status; - - if (pd()->with_bias()) { - using namespace format_tag; - - auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST); - auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS); - - switch (pd()->dst_tag_) { - case ncdhw: case nchw: case ncw: - compute_bwd_bias_ncdhw(diff_dst, diff_bias); - break; - case nCdhw8c: case nChw8c: case nCw8c: - compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias); - break; - case nCdhw16c: case nChw16c: case nCw16c: - compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias); - break; - default: - compute_bwd_bias(diff_dst, diff_bias); - break; - } - } - return status::success; - } - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); } - void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const; - void compute_bwd_bias_ncdhw(const data_t *diff_dst, - data_t *diff_bias) const; - template <int blksize> void compute_bwd_bias_nCdhwXc( - const data_t *diff_dst, data_t *diff_bias) const; - - primitive_t *conv_p_; -}; - -} -} -} - -#endif - -// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s |