diff options
Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp')
-rw-r--r-- | thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp | 280 |
1 files changed, 0 insertions, 280 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp deleted file mode 100644 index 3dbe07d743..0000000000 --- a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp +++ /dev/null @@ -1,280 +0,0 @@ -/******************************************************************************* -* Copyright 2018 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "mkldnn.h" - -#include "mkldnn_traits.hpp" -#include "nstl.hpp" - -#include "jit_generator.hpp" - -#include "gemm.hpp" - -#include "f32/jit_avx512_common_gemm_f32.hpp" -#include "f32/jit_avx_gemm_f32.hpp" -#include "f32/ref_gemm_f32.hpp" - -#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp" -#include "s8x8s32/simple_gemm_s8s8s32.hpp" -#include "s8x8s32/ref_gemm_s8x8s32.hpp" - -#include "os_blas.hpp" - -/* USE_MKL USE_CBLAS effect - * ------- --------- ------ - * yes yes use Intel(R) MKL CBLAS - * yes no use jit - * no yes system-dependent CBLAS - * no no use jit - */ - -namespace mkldnn { -namespace impl { -namespace cpu { - -mkldnn_status_t check_gemm_input(const char *transa, const char *transb, - const int *M, const int *N, const int *K, const int *lda, - const int *ldb, const int *ldc, const float *alpha, const float *beta, - const bool with_bias) { - if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta)) - return mkldnn_invalid_arguments; - if (with_bias && *beta != 0) - return mkldnn_unimplemented; - bool consistency = true - && utils::one_of(*transa, 'T', 't', 'N', 'n') - && utils::one_of(*transb, 'T', 't', 'N', 'n') - && *M >= 0 - && *N >= 0 - && *K >= 0; - - if (!consistency) - return mkldnn_invalid_arguments; - bool isTransA = utils::one_of(*transa, 'T', 't'); - bool isTransB = utils::one_of(*transb, 'T', 't'); - int nrowA = isTransA ? *K : *M; - int nrowB = isTransB ? *N : *K; - consistency = true - && *lda >= nstl::max(1, nrowA) - && *ldb >= nstl::max(1, nrowB) - && *ldc >= nstl::max(1, *M); - if (!consistency) - return mkldnn_invalid_arguments; - - return mkldnn_success; -} - -mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc, - const char *transa, const char *transb, const int *M, const int *N, - const int *K, const int *lda, const int *ldb, const int *ldc, - const float *alpha, const float *beta, const bool with_bias) { - if (offsetc == nullptr) - return mkldnn_invalid_arguments; - if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r')) - return mkldnn_invalid_arguments; - - return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha, - beta, with_bias); -} - -mkldnn_status_t extended_sgemm(const char *transa, const char *transb, - const int *M, const int *N, const int *K, const float *alpha, - const float *A, const int *lda, const float *B, const int *ldb, - const float *beta, float *C, const int *ldc, - const float *bias, const bool force_jit_gemm) { - mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K, - lda, ldb, ldc, alpha, beta, bias != nullptr); - if (status != mkldnn_success) - return status; - -#ifdef USE_CBLAS - if (!force_jit_gemm) { - bool trA = *transa == 't' || *transa == 'T'; - bool trB = *transb == 't' || *transb == 'T'; - CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans; - CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans; - cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB, - *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc); - - if (bias) { - // Add bias if necessary (bias is applied to columns of C) - cblas_int incx = 1, incy = 1; - parallel_nd(*N, [&](int n) { - ptrdiff_t offset = (ptrdiff_t)n * (*ldc); - cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy); - }); - } - return mkldnn_success; - } -#endif - - if (mayiuse(avx512_common)) - return jit_avx512_common_gemm_f32(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); - else if (mayiuse(avx)) - return jit_avx_gemm_f32(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); - else - return ref_gemm<float>(transa, transb, - M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias); -} - -template <typename b_dt> -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const b_dt *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co) { - mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb, - M, N, K, LDA, LDB, LDC, alpha, beta, false); - if (status != mkldnn_success) - return status; - - if (*M == 0 || *N == 0 || *K == 0) - return mkldnn_success; - -#if USE_MKL_IGEMM - bool OCisR = (*offsetc == 'R' || *offsetc == 'r'); - bool OCisC = (*offsetc == 'C' || *offsetc == 'c'); - bool AisN = (*transa == 'N' || *transa == 'n'); - bool BisN = (*transb == 'N' || *transb == 'n'); - - if (data_traits<b_dt>::data_type == data_type::u8) { - CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans; - CBLAS_OFFSET Cblas_offsetc = - OCisR - ? CblasRowOffset - : OCisC - ? CblasColOffset - : CblasFixOffset; - cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc, - *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo, - *beta, C, *LDC, co); - return mkldnn_success; - } else { - assert(data_traits<b_dt>::data_type == data_type::s8); - // TODO CBLAS implementation of gemm_s8s8s32 goes here. - // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo - if (utils::everyone_is(0, *ao, *bo)) { - return simple_gemm_s8s8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, - C, LDC, co); - } else { - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } -#else - cpu_isa_t isa = isa_any; - if (mayiuse(avx512_core_vnni)) { - isa = avx512_core_vnni; - } else if (mayiuse(avx512_core)) { - isa = avx512_core; - } - - if (data_traits<b_dt>::data_type == data_type::u8) { - switch (isa) { - case avx512_core: - case avx512_core_vnni: - return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta, - C, LDC, co); - default: - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } else { - assert(data_traits<b_dt>::data_type == data_type::s8); - // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo - if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni)) - && *ao == 0 && *bo == 0) { - return simple_gemm_s8s8s32(transa, transb, offsetc, M, - N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta, - C, LDC, co); - } else { - return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K, - alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co); - } - } -#endif -} - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const int8_t *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co); - -template -mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb, - const char *offsetc, const int *M, const int *N, const int *K, - const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao, - const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta, - int32_t *C, const int *LDC, const int32_t *co); - -} -} -} - -using namespace mkldnn::impl; -using namespace mkldnn::impl::cpu; - -mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb, - const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha, - const float *A, const int64_t *lda, const float *B, const int64_t *ldb, - const float *beta, float *C, const int64_t *ldc) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - - return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32); -} - -mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb, - const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, - const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, - const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, - int32_t *C, const int64_t *ldc, const int32_t *co) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); -} - -mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb, - const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K, - const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao, - const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta, - int32_t *C, const int64_t *ldc, const int32_t *co) { - int M_s32 = (int)*M; - int N_s32 = (int)*N; - int K_s32 = (int)*K; - int lda_s32 = (int)*lda; - int ldb_s32 = (int)*ldb; - int ldc_s32 = (int)*ldc; - - return gemm_s8x8s32<int8_t>(transa, transb, offsetc, &M_s32, &N_s32, &K_s32, - alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co); -} |