summaryrefslogtreecommitdiffstats
path: root/thirdparty/oidn/core/node.h
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/oidn/core/node.h')
-rw-r--r--thirdparty/oidn/core/node.h142
1 files changed, 0 insertions, 142 deletions
diff --git a/thirdparty/oidn/core/node.h b/thirdparty/oidn/core/node.h
deleted file mode 100644
index b9ffe906df..0000000000
--- a/thirdparty/oidn/core/node.h
+++ /dev/null
@@ -1,142 +0,0 @@
-// ======================================================================== //
-// Copyright 2009-2019 Intel Corporation //
-// //
-// Licensed under the Apache License, Version 2.0 (the "License"); //
-// you may not use this file except in compliance with the License. //
-// You may obtain a copy of the License at //
-// //
-// http://www.apache.org/licenses/LICENSE-2.0 //
-// //
-// Unless required by applicable law or agreed to in writing, software //
-// distributed under the License is distributed on an "AS IS" BASIS, //
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
-// See the License for the specific language governing permissions and //
-// limitations under the License. //
-// ======================================================================== //
-
-#pragma once
-
-#include "common.h"
-#include <vector>
-
-namespace oidn {
-
- class Node
- {
- public:
- virtual ~Node() = default;
-
- virtual void execute(stream& sm) = 0;
-
- virtual std::shared_ptr<memory> getDst() const { return nullptr; }
-
- virtual size_t getScratchpadSize() const { return 0; }
- virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
-
- virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
- {
- assert(0); // not supported
- }
- };
-
- // Node wrapping an MKL-DNN primitive
- class MklNode : public Node
- {
- private:
- primitive prim;
- std::unordered_map<int, memory> args;
- std::shared_ptr<memory> scratchpad;
-
- public:
- MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
- : prim(prim),
- args(args)
- {}
-
- size_t getScratchpadSize() const override
- {
- const auto primDesc = prim.get_primitive_desc();
- const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
- if (scratchpadDesc == nullptr)
- return 0;
- return mkldnn_memory_desc_get_size(scratchpadDesc);
- }
-
- void setScratchpad(const std::shared_ptr<memory>& mem) override
- {
- scratchpad = mem;
- args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
- }
-
- void execute(stream& sm) override
- {
- prim.execute(sm, args);
- }
- };
-
- // Convolution node
- class ConvNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> weights;
- std::shared_ptr<memory> bias;
- std::shared_ptr<memory> dst;
-
- public:
- ConvNode(const convolution_forward::primitive_desc& desc,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& weights,
- const std::shared_ptr<memory>& bias,
- const std::shared_ptr<memory>& dst)
- : MklNode(convolution_forward(desc),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_WEIGHTS, *weights },
- { MKLDNN_ARG_BIAS, *bias },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), weights(weights), bias(bias), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
- // Pooling node
- class PoolNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- PoolNode(const pooling_forward::primitive_desc& desc,
- const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : MklNode(pooling_forward(desc),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
- // Reorder node
- class ReorderNode : public MklNode
- {
- private:
- std::shared_ptr<memory> src;
- std::shared_ptr<memory> dst;
-
- public:
- ReorderNode(const std::shared_ptr<memory>& src,
- const std::shared_ptr<memory>& dst)
- : MklNode(reorder(reorder::primitive_desc(*src, *dst)),
- { { MKLDNN_ARG_SRC, *src },
- { MKLDNN_ARG_DST, *dst } }),
- src(src), dst(dst)
- {}
-
- std::shared_ptr<memory> getDst() const override { return dst; }
- };
-
-} // namespace oidn