diff options
Diffstat (limited to 'thirdparty/basis_universal/encoder/basisu_kernels_imp.h')
-rw-r--r-- | thirdparty/basis_universal/encoder/basisu_kernels_imp.h | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/thirdparty/basis_universal/encoder/basisu_kernels_imp.h b/thirdparty/basis_universal/encoder/basisu_kernels_imp.h index 046880517b..dcf1ce069a 100644 --- a/thirdparty/basis_universal/encoder/basisu_kernels_imp.h +++ b/thirdparty/basis_universal/encoder/basisu_kernels_imp.h @@ -548,6 +548,65 @@ namespace CPPSPMD_NAME(basisu_kernels_namespace) } }; + struct update_covar_matrix_16x16 : spmd_kernel + { + void _call( + uint32_t num_vecs, const void* pWeighted_vecs_void, const void* pOrigin_void, const uint32_t* pVec_indices, void* pMatrix16x16_void) + { + const std::pair<vec16F, uint64_t>* pWeighted_vecs = static_cast< const std::pair<vec16F, uint64_t> *>(pWeighted_vecs_void); + + const float* pOrigin = static_cast<const float*>(pOrigin_void); + vfloat org0 = loadu_linear_all(pOrigin), org1 = loadu_linear_all(pOrigin + 4), org2 = loadu_linear_all(pOrigin + 8), org3 = loadu_linear_all(pOrigin + 12); + + vfloat mat[16][4]; + vfloat vzero(zero_vfloat()); + + for (uint32_t i = 0; i < 16; i++) + { + store_all(mat[i][0], vzero); + store_all(mat[i][1], vzero); + store_all(mat[i][2], vzero); + store_all(mat[i][3], vzero); + } + + for (uint32_t k = 0; k < num_vecs; k++) + { + const uint32_t vec_index = pVec_indices[k]; + + const float* pW = pWeighted_vecs[vec_index].first.get_ptr(); + vfloat weight((float)pWeighted_vecs[vec_index].second); + + vfloat vec[4] = { loadu_linear_all(pW) - org0, loadu_linear_all(pW + 4) - org1, loadu_linear_all(pW + 8) - org2, loadu_linear_all(pW + 12) - org3 }; + + vfloat wvec0 = vec[0] * weight, wvec1 = vec[1] * weight, wvec2 = vec[2] * weight, wvec3 = vec[3] * weight; + + for (uint32_t j = 0; j < 16; j++) + { + vfloat vx = ((const float*)vec)[j]; + + store_all(mat[j][0], mat[j][0] + vx * wvec0); + store_all(mat[j][1], mat[j][1] + vx * wvec1); + store_all(mat[j][2], mat[j][2] + vx * wvec2); + store_all(mat[j][3], mat[j][3] + vx * wvec3); + + } // j + + } // k + + float* pMatrix = static_cast<float*>(pMatrix16x16_void); + + float* pDst = pMatrix; + for (uint32_t i = 0; i < 16; i++) + { + storeu_linear_all(pDst, mat[i][0]); + storeu_linear_all(pDst + 4, mat[i][1]); + storeu_linear_all(pDst + 8, mat[i][2]); + storeu_linear_all(pDst + 12, mat[i][3]); + pDst += 16; + } + } + }; + } // namespace using namespace CPPSPMD_NAME(basisu_kernels_namespace); @@ -582,3 +641,7 @@ void CPPSPMD_NAME(find_lowest_error_linear_rgb_4_N)(int64_t* pDistance, const co spmd_call< find_lowest_error_linear_rgb_4_N >(pDistance, pBlock_colors, pSrc_pixels, n, early_out_error); } +void CPPSPMD_NAME(update_covar_matrix_16x16)(uint32_t num_vecs, const void* pWeighted_vecs, const void* pOrigin, const uint32_t *pVec_indices, void* pMatrix16x16) +{ + spmd_call < update_covar_matrix_16x16 >(num_vecs, pWeighted_vecs, pOrigin, pVec_indices, pMatrix16x16); +} |