summaryrefslogtreecommitdiffstats
path: root/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp')
-rw-r--r--thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp502
1 files changed, 0 insertions, 502 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
deleted file mode 100644
index d61903c32d..0000000000
--- a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
+++ /dev/null
@@ -1,502 +0,0 @@
-/*******************************************************************************
-* Copyright 2018 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#ifndef CPU_REF_DECONVOLUTION_HPP
-#define CPU_REF_DECONVOLUTION_HPP
-
-#include <assert.h>
-#include <string.h>
-
-#include "c_types_map.hpp"
-#include "type_helpers.hpp"
-#include "utils.hpp"
-#include "primitive_iterator.hpp"
-
-#include "cpu_convolution_pd.hpp"
-#include "cpu_deconvolution_pd.hpp"
-#include "cpu_primitive.hpp"
-
-namespace mkldnn {
-namespace impl {
-namespace cpu {
-
-static status_t compute_blocked_format(bool with_groups,
- const memory_desc_t *oi_md, memory_desc_t *io_md)
-{
- /* Computes blocking for *i*o* format from *o*i* format */
-
- bool sanity_check_ok = true
- && oi_md->ndims == io_md->ndims
- && oi_md->format_kind == format_kind::blocked;
- if (!sanity_check_ok) return status::invalid_arguments;
-
- const blocking_desc_t &oi_blk = oi_md->format_desc.blocking;
- blocking_desc_t io_blk = io_md->format_desc.blocking;
-
- io_md->format_kind = format_kind::blocked;
- io_blk = oi_blk;
-
- const int ID_OC = 0 + with_groups;
- const int ID_IC = 1 + with_groups;
-
- nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]);
- for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) {
- if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) {
- io_blk.inner_idxs[i_blk] =
- (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC);
- }
- }
-
- return memory_desc_init_by_blocking_desc(*io_md, io_blk);
-}
-
-static status_t conv_descr_create(const deconvolution_desc_t *dd,
- convolution_desc_t *cd)
-{
- using namespace prop_kind;
- alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
- ? alg_kind::convolution_direct : alg_kind::convolution_winograd;
-
- const memory_desc_t *src_md, *dst_md, *d_weights_d;
- prop_kind_t prop_kind;
- memory_desc_t c_weights_d;
- if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
- prop_kind = backward_data;
- src_md = &dd->dst_desc;
- dst_md = &dd->src_desc;
- d_weights_d = &dd->weights_desc;
- } else if (dd->prop_kind == backward_data) {
- prop_kind = forward_training;
- src_md = &dd->diff_dst_desc;
- dst_md = &dd->diff_src_desc;
- d_weights_d = &dd->weights_desc;
- } else {
- prop_kind = dd->prop_kind;
- src_md = &dd->diff_dst_desc;
- dst_md = &dd->src_desc;
- d_weights_d = &dd->diff_weights_desc;
- }
-
- const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
-
- /* create weights desc for convolution */
- c_weights_d = *d_weights_d;
-
- const int ID_OC = 0 + with_groups;
- const int ID_IC = 1 + with_groups;
-
- nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]);
- nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]);
- nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]);
-
- if (c_weights_d.format_kind != format_kind::any)
- CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d));
-
- return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
- prop_kind != backward_weights ? &dd->bias_desc : nullptr,
- dst_md, dd->strides, dd->dilates,
- dd->padding[0], dd->padding[1], dd->padding_kind);
-}
-
-struct ref_deconvolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_fwd_pd_t {
- pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_fwd_pd_t(other)
- , conv_pd_(other.conv_pd_->clone())
- , conv_supports_bias_(other.conv_supports_bias_)
- , dst_tag_(other.dst_tag_)
- {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- CHECK(conv_descr_create(desc(), &cd));
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- conv_supports_bias_ =
- static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_)
- ->support_bias();
- bool output_f32 = utils::everyone_is(data_type::f32,
- desc()->accum_data_type, desc()->dst_desc.data_type);
-
- bool ok = true
- && conv_pd_->weights_md()->extra.flags == 0
- /* deconv reference code can process only f32 bias */
- && IMPLICATION(with_bias(),
- conv_supports_bias_ || output_f32);
- if (ok) return status::success;
-
- delete conv_pd_;
- }
- conv_pd_ = nullptr;
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace format_tag;
- bool ok = true
- && is_fwd()
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd)
- && attr()->post_ops_.has_default_values();
-
- if (ok) {
- CHECK(init_convolution());
- if (weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->weights_md(), &desc_.weights_desc));
- weights_md_ = desc_.weights_desc;
- }
- if (src_md_.format_kind == format_kind::any)
- src_md_ = *conv_pd_->diff_dst_md();
- if (dst_md_.format_kind == format_kind::any)
- dst_md_ = *conv_pd_->diff_src_md();
- if (bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(bias_md_, x));
-
- dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
- utils::pick(ndims() - 3, ncw, nchw, ncdhw),
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
- utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- bool conv_supports_bias_;
- format_tag_t dst_tag_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_fwd_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
- conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
- if (pd()->with_bias() && pd()->conv_supports_bias_)
- conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS);
- conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- conv_p_->execute(conv_ctx);
-
- if (pd()->with_bias() && !pd()->conv_supports_bias_) {
- using namespace format_tag;
-
- auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
- auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
-
- switch (pd()->dst_tag_) {
- case ncdhw: case nchw: case ncw:
- compute_fwd_bias_ncdhw(bias, dst);
- break;
- case nCdhw8c: case nChw8c: case nCw8c:
- compute_fwd_bias_nCdhwXc<8>(bias, dst);
- break;
- case nCdhw16c: case nChw16c: case nCw16c:
- compute_fwd_bias_nCdhwXc<16>(bias, dst);
- break;
- default:
- compute_fwd_bias(bias, dst);
- break;
- }
- }
- return status::success;
- }
-
-private:
- void compute_fwd_bias(const data_t *bias, data_t *dst) const;
- void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const;
- template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias,
- data_t *dst) const;
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- primitive_t *conv_p_;
-};
-
-struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
- pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_bwd_data_pd_t(other)
- , conv_pd_(other.conv_pd_->clone()) {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- status_t status = conv_descr_create(desc(), &cd);
- if (status != status::success) return status;
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- if (conv_pd_->weights_md()->extra.flags == 0)
- return status::success;
- delete conv_pd_;
- }
-
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace data_type;
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_data
- && utils::everyone_is(data_type::f32,
- desc()->diff_src_desc.data_type,
- desc()->weights_desc.data_type,
- desc()->diff_dst_desc.data_type)
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd);
-
- if (ok) {
- CHECK(init_convolution());
- if (weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->weights_md(), &desc_.weights_desc));
- weights_md_ = desc_.weights_desc;
- }
- if (diff_src_md_.format_kind == format_kind::any)
- diff_src_md_ = *conv_pd_->dst_md();
- if (diff_dst_md_.format_kind == format_kind::any)
- diff_dst_md_ = *conv_pd_->src_md();
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_bwd_data_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
- conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
- conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- conv_p_->execute(conv_ctx);
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- primitive_t *conv_p_;
-};
-
-struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
- struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
- pd_t(engine_t *engine,
- const deconvolution_desc_t *adesc,
- const primitive_attr_t *attr,
- const deconvolution_fwd_pd_t *hint_fwd_pd)
- : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
- , conv_pd_(nullptr)
- {}
-
- pd_t(const pd_t &other)
- : cpu_deconvolution_bwd_weights_pd_t(other)
- , conv_pd_(other.conv_pd_->clone())
- , dst_tag_(other.dst_tag_)
- {}
-
- ~pd_t() { delete conv_pd_; }
-
- DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);
-
- status_t init_convolution() {
- using namespace types;
-
- convolution_desc_t cd;
- status_t status = conv_descr_create(desc(), &cd);
- if (status != status::success) return status;
-
- mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
- &attr_, nullptr);
- while (++it != it.end()) {
- conv_pd_ = *it;
- if (conv_pd_->diff_weights_md()->extra.flags == 0)
- return status::success;
- delete conv_pd_;
- }
- return status::unimplemented;
- }
-
- status_t init() {
- using namespace format_tag;
- bool ok = true
- && desc()->prop_kind == prop_kind::backward_weights
- && utils::everyone_is(data_type::f32,
- desc()->src_desc.data_type,
- desc()->diff_weights_desc.data_type,
- desc()->diff_dst_desc.data_type)
- && utils::one_of(desc()->alg_kind,
- alg_kind::deconvolution_direct,
- alg_kind::deconvolution_winograd)
- && attr()->has_default_values();
- if (ok) {
- CHECK(init_convolution());
- if (diff_weights_md_.format_kind == format_kind::any) {
- CHECK(compute_blocked_format(with_groups(),
- conv_pd_->diff_weights_md(),
- &desc_.diff_weights_desc));
- diff_weights_md_ = desc_.diff_weights_desc;
- }
- if (src_md_.format_kind == format_kind::any)
- src_md_ = *conv_pd_->diff_dst_md();
- if (diff_dst_md_.format_kind == format_kind::any)
- diff_dst_md_ = *conv_pd_->src_md();
- if (diff_bias_md_.format_kind == format_kind::any)
- CHECK(memory_desc_init_by_tag(diff_bias_md_, x));
-
- dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
- utils::pick(ndims() - 3, ncw, nchw, ncdhw),
- utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
- utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
-
- return status::success;
- }
-
- return status::unimplemented;
- }
-
- virtual void init_scratchpad_md() override {
- scratchpad_md_ = *conv_pd_->scratchpad_md();
- }
-
- primitive_desc_t *conv_pd_;
- format_tag_t dst_tag_;
- };
-
- typedef typename prec_traits<data_type::f32>::type data_t;
-
- ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd)
- { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
- ~ref_deconvolution_bwd_weights_t() { delete conv_p_; }
-
- virtual status_t execute(const exec_ctx_t &ctx) const override {
- const auto &args = ctx.args();
- exec_args_t conv_args;
- conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
- conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
- conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS);
- if (!types::is_zero_md(pd()->scratchpad_md()))
- conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
- const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
-
- status_t status = conv_p_->execute(conv_ctx);
- if (status != status::success) return status;
-
- if (pd()->with_bias()) {
- using namespace format_tag;
-
- auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
- auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
-
- switch (pd()->dst_tag_) {
- case ncdhw: case nchw: case ncw:
- compute_bwd_bias_ncdhw(diff_dst, diff_bias);
- break;
- case nCdhw8c: case nChw8c: case nCw8c:
- compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias);
- break;
- case nCdhw16c: case nChw16c: case nCw16c:
- compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias);
- break;
- default:
- compute_bwd_bias(diff_dst, diff_bias);
- break;
- }
- }
- return status::success;
- }
-
-private:
- const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
- void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const;
- void compute_bwd_bias_ncdhw(const data_t *diff_dst,
- data_t *diff_bias) const;
- template <int blksize> void compute_bwd_bias_nCdhwXc(
- const data_t *diff_dst, data_t *diff_bias) const;
-
- primitive_t *conv_p_;
-};
-
-}
-}
-}
-
-#endif
-
-// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s