我正在尝试将SSE函数转换为AVX。该函数进行矢量矩阵乘法,这是我正在使用的SSE代码:
void multiply_matrix_by_vector_SSE(float* m, float* v, float* result, unsigned const int vector_dims)
{
size_t i, j;
for (i = 0; i < vector_dims; ++i)
{
__m128 acc = _mm_setzero_ps();
for (j = 0; j < vector_dims; j += 4)
{
__m128 vec = _mm_load_ps(&v[j]);
__m128 mat = _mm_load_ps(&m[j + vector_dims * i]);
//acc = _mm_add_ps(acc, _mm_mul_ps(mat, vec));
acc = _mm_fmadd_ps(mat, vec, acc);
}
acc = _mm_hadd_ps(acc, acc);
acc = _mm_hadd_ps(acc, acc);
_mm_store_ss(&result[i], acc);
}
}
这就是我对AVX提出的建议:
void multiply_matrix_by_vector_AVX(float* m, float* v, float* result, unsigned const int vector_dims)
{
size_t i, j;
for (i = 0; i < vector_dims; ++i)
{
__m256 acc = _mm256_setzero_ps();
for (j = 0; j < vector_dims; j += 8)
{
__m256 vec = _mm256_load_ps(&v[j]);
__m256 mat = _mm256_load_ps(&m[j + vector_dims * i]);
acc = _mm256_fmadd_ps(mat, vec, acc);
}
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
_mm256_store_ps(&result[i], acc);
}
}
但是,AVX代码崩溃(Access violation reading location 0xFFFFFFFFFFFFFFFF
)。
谁能帮助我使我的AVX功能正常工作?
PS:我传递给函数的矩阵和向量的大小始终是8的倍数。而且,传递给SSE函数的数组是16位对齐的(__declspec(align(16))float* = generate_matrix(256);
),传递给AVX函数的数组是32位的对齐(__declspec(align(32))float* = generate_matrix(256);
);
不幸的是,使用水平加法不会像平时一样扩展到256位,因为指令(以及大多数其他指令)是“行进的”-它的作用就像两个haddps
并行,一个在上半部,一个在下半部,没有混合,因此下半部分和上半部分不会相加。
而且,它当然仍然不是打包结果,并且打包存储中有一个对齐存储,其中写入了一些未对齐的地址,并且将失败(该错误有点怪异,但无论如何)。
无论如何,让我们确定水平总和:(未经测试)
// this part still works
acc = _mm256_hadd_ps(acc, acc);
acc = _mm256_hadd_ps(acc, acc);
// this is new
__m128 acc1 = _mm256_extractf128_ps(acc, 0);
__m128 acc2 = _mm256_extractf128_ps(acc, 1);
acc1 = _mm_add_ss(acc1, acc2);
// do scalar store, obviously
_mm_store_ss(&result[i], acc1);
顺便说一下,内部循环需要10个独立的链(和10个累加器),以使Haswell的吞吐量最大化。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句