summaryrefslogtreecommitdiffstats
path: root/thirdparty/basis_universal/encoder/basisu_kernels_imp.h
diff options
context:
space:
mode:
Diffstat (limited to 'thirdparty/basis_universal/encoder/basisu_kernels_imp.h')
-rw-r--r--thirdparty/basis_universal/encoder/basisu_kernels_imp.h63
1 files changed, 63 insertions, 0 deletions
diff --git a/thirdparty/basis_universal/encoder/basisu_kernels_imp.h b/thirdparty/basis_universal/encoder/basisu_kernels_imp.h
index 046880517b..dcf1ce069a 100644
--- a/thirdparty/basis_universal/encoder/basisu_kernels_imp.h
+++ b/thirdparty/basis_universal/encoder/basisu_kernels_imp.h
@@ -548,6 +548,65 @@ namespace CPPSPMD_NAME(basisu_kernels_namespace)
}
};
+ struct update_covar_matrix_16x16 : spmd_kernel // SPMD kernel: builds a 16x16 weighted outer-product (covariance-style) matrix from selected 16-float vectors.
+ {
+ void _call( // Sums weight * (v - origin) * (v - origin)^T over the selected vectors and writes the 16x16 result.
+ uint32_t num_vecs, const void* pWeighted_vecs_void, const void* pOrigin_void, const uint32_t* pVec_indices, void* pMatrix16x16_void)
+ { // num_vecs entries are read from pVec_indices; each one indexes pWeighted_vecs_void. pMatrix16x16_void is only written, never read.
+ const std::pair<vec16F, uint64_t>* pWeighted_vecs = static_cast< const std::pair<vec16F, uint64_t> *>(pWeighted_vecs_void);
+ // pOrigin_void is read as 16 floats in four groups at float offsets 0/4/8/12 (assumes a 4-lane vfloat -- TODO confirm).
+ const float* pOrigin = static_cast<const float*>(pOrigin_void);
+ vfloat org0 = loadu_linear_all(pOrigin), org1 = loadu_linear_all(pOrigin + 4), org2 = loadu_linear_all(pOrigin + 8), org3 = loadu_linear_all(pOrigin + 12);
+ // mat[row][group]: 16 rows, each held as four vfloat accumulators.
+ vfloat mat[16][4];
+ vfloat vzero(zero_vfloat());
+ // Clear every accumulator before summing.
+ for (uint32_t i = 0; i < 16; i++)
+ {
+ store_all(mat[i][0], vzero);
+ store_all(mat[i][1], vzero);
+ store_all(mat[i][2], vzero);
+ store_all(mat[i][3], vzero);
+ }
+ // Accumulate one weighted outer product per selected vector.
+ for (uint32_t k = 0; k < num_vecs; k++)
+ {
+ const uint32_t vec_index = pVec_indices[k];
+ // pW points at the vector's floats (.first); the uint64_t weight (.second) is converted to float, which is inexact above 2^24.
+ const float* pW = pWeighted_vecs[vec_index].first.get_ptr();
+ vfloat weight((float)pWeighted_vecs[vec_index].second);
+ // Subtract the origin from the vector, one group of floats at a time.
+ vfloat vec[4] = { loadu_linear_all(pW) - org0, loadu_linear_all(pW + 4) - org1, loadu_linear_all(pW + 8) - org2, loadu_linear_all(pW + 12) - org3 };
+ // Apply the weight once here so the inner loop only multiplies by the per-row scalar.
+ vfloat wvec0 = vec[0] * weight, wvec1 = vec[1] * weight, wvec2 = vec[2] * weight, wvec3 = vec[3] * weight;
+ // Row j gains (element j of the centered vector) * (weighted centered vector).
+ for (uint32_t j = 0; j < 16; j++)
+ {
+ vfloat vx = ((const float*)vec)[j]; // NOTE(review): views vfloat[4] as 16 packed floats and presumably broadcasts element j to all lanes -- confirm vfloat layout.
+
+ store_all(mat[j][0], mat[j][0] + vx * wvec0);
+ store_all(mat[j][1], mat[j][1] + vx * wvec1);
+ store_all(mat[j][2], mat[j][2] + vx * wvec2);
+ store_all(mat[j][3], mat[j][3] + vx * wvec3);
+
+ } // j
+
+ } // k
+ // Write the accumulators out as 16 consecutive rows of 16 floats each.
+ float* pMatrix = static_cast<float*>(pMatrix16x16_void);
+
+ float* pDst = pMatrix;
+ for (uint32_t i = 0; i < 16; i++)
+ {
+ storeu_linear_all(pDst, mat[i][0]);
+ storeu_linear_all(pDst + 4, mat[i][1]);
+ storeu_linear_all(pDst + 8, mat[i][2]);
+ storeu_linear_all(pDst + 12, mat[i][3]);
+ pDst += 16; // advance to the next 16-float row
+ }
+ }
+ };
+
} // namespace
using namespace CPPSPMD_NAME(basisu_kernels_namespace);
@@ -582,3 +641,7 @@ void CPPSPMD_NAME(find_lowest_error_linear_rgb_4_N)(int64_t* pDistance, const co
spmd_call< find_lowest_error_linear_rgb_4_N >(pDistance, pBlock_colors, pSrc_pixels, n, early_out_error);
}
+void CPPSPMD_NAME(update_covar_matrix_16x16)(uint32_t num_vecs, const void* pWeighted_vecs, const void* pOrigin, const uint32_t *pVec_indices, void* pMatrix16x16)
+{ // Free-function entry point: forwards all five arguments unchanged to the update_covar_matrix_16x16 kernel.
+ spmd_call < update_covar_matrix_16x16 >(num_vecs, pWeighted_vecs, pOrigin, pVec_indices, pMatrix16x16); // presumably dispatches to the kernel's _call -- confirm against spmd_call
+}