Spaces:
Running
Running
1.5 bit: we can do even better (llama/5999)
Browse files* iq1_s: we can do even better
Spent one of the 4 scale bits on a signs of a 0.125 shift.
I.e., quants are now -1 + delta, delta, 1 + delta, where delta
is +/- 0.125.
CUDA works, same performance as before.
PPL(LLaMA-v2-7B) is now 11.85!
* iq1_s: make scalar and AVX2 work with the new version
* iq1_s: make Neon work with new version.
~10% drop in performance, so will need some more work.
* iq1_s: make Metal work with new version
* iq1_s: very slightly faster dequantize on Metal
* iq1_s: fix dequantize on the CPU
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
- ggml-cuda.cu +16 -20
- ggml-metal.metal +10 -8
- ggml-quants.c +56 -27
ggml-cuda.cu
CHANGED
|
@@ -1722,22 +1722,15 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
|
|
| 1722 |
const int il = tid/8; // 0...3
|
| 1723 |
const int ib = tid%8; // 0...7
|
| 1724 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 1725 |
-
const float
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
grid32[0] =
|
| 1729 |
-
grid32[1] =
|
| 1730 |
-
grid32[0]
|
| 1731 |
for (int j = 0; j < 8; ++j) {
|
| 1732 |
-
y[j] = d * q[j];
|
| 1733 |
-
}
|
| 1734 |
-
#else
|
| 1735 |
-
const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)));
|
| 1736 |
-
for (int j = 0; j < 4; ++j) {
|
| 1737 |
-
y[j+0] = d * ((grid[j] & 0xf) - 1);
|
| 1738 |
-
y[j+4] = d * ((grid[j] >> 4) - 1);
|
| 1739 |
}
|
| 1740 |
-
#endif
|
| 1741 |
#else
|
| 1742 |
assert(false);
|
| 1743 |
#endif
|
|
@@ -4560,22 +4553,25 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
|
|
| 4560 |
const int * q8 = (const int *)bq8_1[ib32].qs;
|
| 4561 |
for (int l = 0; l < 4; ++l) {
|
| 4562 |
const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
| 4563 |
-
int grid0 =
|
| 4564 |
-
int grid1 =
|
| 4565 |
sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
|
| 4566 |
}
|
| 4567 |
#else
|
| 4568 |
-
const int8_t
|
| 4569 |
for (int l = 0; l < 4; ++l) {
|
| 4570 |
const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
| 4571 |
for (int j = 0; j < 4; ++j) {
|
| 4572 |
-
sumi += q8[j] * (
|
| 4573 |
}
|
| 4574 |
q8 += 8;
|
| 4575 |
}
|
| 4576 |
#endif
|
| 4577 |
-
const float
|
| 4578 |
-
|
|
|
|
|
|
|
|
|
|
| 4579 |
#else
|
| 4580 |
assert(false);
|
| 4581 |
return 0.f;
|
|
|
|
| 1722 |
const int il = tid/8; // 0...3
|
| 1723 |
const int ib = tid%8; // 0...7
|
| 1724 |
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
|
| 1725 |
+
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
|
| 1726 |
+
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
|
| 1727 |
+
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
|
| 1728 |
+
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
|
| 1729 |
+
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
|
| 1730 |
+
grid32[0] &= 0x0f0f0f0f;
|
| 1731 |
for (int j = 0; j < 8; ++j) {
|
| 1732 |
+
y[j] = d * (q[j] + delta);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1733 |
}
|
|
|
|
| 1734 |
#else
|
| 1735 |
assert(false);
|
| 1736 |
#endif
|
|
|
|
| 4553 |
const int * q8 = (const int *)bq8_1[ib32].qs;
|
| 4554 |
for (int l = 0; l < 4; ++l) {
|
| 4555 |
const int * grid = (const int *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
| 4556 |
+
int grid0 = grid[0] & 0x0f0f0f0f;
|
| 4557 |
+
int grid1 = (grid[0] >> 4) & 0x0f0f0f0f;
|
| 4558 |
sumi = __dp4a(q8[2*l+1], grid1, __dp4a(q8[2*l+0], grid0, sumi));
|
| 4559 |
}
|
| 4560 |
#else
|
| 4561 |
+
const int8_t * q8 = bq8_1[ib32].qs;
|
| 4562 |
for (int l = 0; l < 4; ++l) {
|
| 4563 |
const uint8_t * grid = (const uint8_t *)(iq1s_grid_gpu + (bq1->qs[4*ib32+l] | (((bq1->qh[ib32] >> 3*l) & 7) << 8)));
|
| 4564 |
for (int j = 0; j < 4; ++j) {
|
| 4565 |
+
sumi += q8[j] * (grid[j] & 0xf) + q8[j+4] * (grid[j] >> 4);
|
| 4566 |
}
|
| 4567 |
q8 += 8;
|
| 4568 |
}
|
| 4569 |
#endif
|
| 4570 |
+
const float delta = bq1->qh[ib32] & 0x8000 ? -1-IQ1S_DELTA : -1+IQ1S_DELTA;
|
| 4571 |
+
const float d1q = (float)bq1->d * (2*((bq1->qh[ib32] >> 12) & 7) + 1);
|
| 4572 |
+
const float d = d1q * __low2float (bq8_1[ib32].ds);
|
| 4573 |
+
const float m = d1q * __high2float(bq8_1[ib32].ds);
|
| 4574 |
+
return d * sumi + m * delta;
|
| 4575 |
#else
|
| 4576 |
assert(false);
|
| 4577 |
return 0.f;
|
ggml-metal.metal
CHANGED
|
@@ -4377,7 +4377,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4377 |
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
| 4378 |
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
| 4379 |
}
|
| 4380 |
-
sumf[row] += (float)dh[0] * (sum -
|
| 4381 |
|
| 4382 |
dh += nb*sizeof(block_iq1_s)/2;
|
| 4383 |
qs += nb*sizeof(block_iq1_s);
|
|
@@ -5076,14 +5076,16 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|
| 5076 |
const float d = xb->d;
|
| 5077 |
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
| 5078 |
device const uint16_t * qh = xb->qh;
|
| 5079 |
-
const float dl = d * (2*(qh[ib32] >> 12) + 1);
|
| 5080 |
-
|
| 5081 |
-
|
|
|
|
|
|
|
| 5082 |
for (int i = 0; i < 4; ++i) {
|
| 5083 |
-
reg[0][i] = dl * (grid1[i] & 0xf)
|
| 5084 |
-
reg[1][i] = dl * (grid1[i] >> 4)
|
| 5085 |
-
reg[2][i] = dl * (grid2[i] & 0xf)
|
| 5086 |
-
reg[3][i] = dl * (grid2[i] >> 4)
|
| 5087 |
}
|
| 5088 |
}
|
| 5089 |
|
|
|
|
| 4377 |
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
| 4378 |
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
| 4379 |
}
|
| 4380 |
+
sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
|
| 4381 |
|
| 4382 |
dh += nb*sizeof(block_iq1_s)/2;
|
| 4383 |
qs += nb*sizeof(block_iq1_s);
|
|
|
|
| 5076 |
const float d = xb->d;
|
| 5077 |
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
| 5078 |
device const uint16_t * qh = xb->qh;
|
| 5079 |
+
const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
|
| 5080 |
+
const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
|
| 5081 |
+
const uint16_t h = qh[ib32] >> 6*il;
|
| 5082 |
+
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
|
| 5083 |
+
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
|
| 5084 |
for (int i = 0; i < 4; ++i) {
|
| 5085 |
+
reg[0][i] = dl * (grid1[i] & 0xf) + ml;
|
| 5086 |
+
reg[1][i] = dl * (grid1[i] >> 4) + ml;
|
| 5087 |
+
reg[2][i] = dl * (grid2[i] & 0xf) + ml;
|
| 5088 |
+
reg[3][i] = dl * (grid2[i] >> 4) + ml;
|
| 5089 |
}
|
| 5090 |
}
|
| 5091 |
|
ggml-quants.c
CHANGED
|
@@ -3456,11 +3456,12 @@ void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, in
|
|
| 3456 |
const uint16_t * qh = x[i].qh;
|
| 3457 |
|
| 3458 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 3459 |
-
const float dl = d * (2*(qh[ib] >> 12) + 1);
|
|
|
|
| 3460 |
for (int l = 0; l < 4; ++l) {
|
| 3461 |
const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
|
| 3462 |
for (int j = 0; j < 8; ++j) {
|
| 3463 |
-
y[j] = dl * grid[j];
|
| 3464 |
}
|
| 3465 |
y += 8;
|
| 3466 |
}
|
|
@@ -9582,7 +9583,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9582 |
const uint8_t * qs = x[i].qs;
|
| 9583 |
const uint16_t * qh = x[i].qh;
|
| 9584 |
|
| 9585 |
-
int sumi1 = 0, sumi2 = 0;
|
| 9586 |
|
| 9587 |
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
| 9588 |
|
|
@@ -9601,12 +9602,16 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9601 |
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
|
| 9602 |
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
|
| 9603 |
|
| 9604 |
-
|
| 9605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9606 |
|
| 9607 |
}
|
| 9608 |
|
| 9609 |
-
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2);
|
| 9610 |
}
|
| 9611 |
|
| 9612 |
*s = sumf;
|
|
@@ -9614,6 +9619,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9614 |
#elif defined __AVX2__
|
| 9615 |
|
| 9616 |
__m256 accum = _mm256_setzero_ps();
|
|
|
|
| 9617 |
for (int i = 0; i < nb; ++i) {
|
| 9618 |
|
| 9619 |
const int8_t * q8 = y[i].qs;
|
|
@@ -9621,6 +9627,7 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9621 |
const uint16_t * qh = x[i].qh;
|
| 9622 |
|
| 9623 |
__m256i sumi = _mm256_setzero_si256();
|
|
|
|
| 9624 |
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
| 9625 |
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
|
| 9626 |
iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
|
|
@@ -9632,17 +9639,23 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9632 |
|
| 9633 |
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
|
| 9634 |
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
|
| 9635 |
-
const
|
| 9636 |
-
const
|
|
|
|
|
|
|
| 9637 |
|
| 9638 |
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
|
|
|
|
|
|
|
| 9639 |
}
|
| 9640 |
|
| 9641 |
-
|
|
|
|
|
|
|
| 9642 |
|
| 9643 |
}
|
| 9644 |
|
| 9645 |
-
*s = hsum_float_8(accum);
|
| 9646 |
|
| 9647 |
#else
|
| 9648 |
|
|
@@ -9653,9 +9666,10 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9653 |
const uint8_t * qs = x[i].qs;
|
| 9654 |
const uint16_t * qh = x[i].qh;
|
| 9655 |
|
| 9656 |
-
int sumi = 0;
|
| 9657 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 9658 |
-
const int ls = 2*(qh[ib] >> 12) + 1;
|
|
|
|
| 9659 |
int lsum = 0;
|
| 9660 |
for (int l = 0; l < 4; ++l) {
|
| 9661 |
const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
|
|
@@ -9664,11 +9678,12 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void
|
|
| 9664 |
}
|
| 9665 |
q8 += 8;
|
| 9666 |
}
|
| 9667 |
-
sumi
|
|
|
|
| 9668 |
qs += 4;
|
| 9669 |
}
|
| 9670 |
|
| 9671 |
-
sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * sumi;
|
| 9672 |
}
|
| 9673 |
|
| 9674 |
*s = sumf;
|
|
@@ -11438,7 +11453,7 @@ static int iq1_find_best_neighbour(const uint16_t * restrict neighbours, const u
|
|
| 11438 |
}
|
| 11439 |
|
| 11440 |
static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
| 11441 |
-
const float * restrict xval, const float * restrict weight, float scale, int8_t * restrict L, int ngrid) {
|
| 11442 |
int num_neighbors = neighbours[0];
|
| 11443 |
GGML_ASSERT(num_neighbors > 0);
|
| 11444 |
float best_score = FLT_MAX;
|
|
@@ -11447,7 +11462,7 @@ static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const
|
|
| 11447 |
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 11448 |
float d2 = 0;
|
| 11449 |
for (int i = 0; i < 8; ++i) {
|
| 11450 |
-
float q = (pg[i] -
|
| 11451 |
float w = weight[i];
|
| 11452 |
float diff = scale*q - xval[i];
|
| 11453 |
d2 += w*diff*diff;
|
|
@@ -11463,7 +11478,7 @@ static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const
|
|
| 11463 |
float d2 = 0;
|
| 11464 |
for (int j = 0; j < 8; ++j) {
|
| 11465 |
float w = weight[j];
|
| 11466 |
-
float q = (grid_i[j] -
|
| 11467 |
float diff = scale*q - xval[i];
|
| 11468 |
d2 += w*diff*diff;
|
| 11469 |
}
|
|
@@ -11480,7 +11495,7 @@ static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const
|
|
| 11480 |
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 11481 |
float sumqx = 0, sumq2 = 0;
|
| 11482 |
for (int i = 0; i < 8; ++i) {
|
| 11483 |
-
float q = (pg[i] -
|
| 11484 |
float w = weight[i];
|
| 11485 |
sumqx += w*q*xval[i];
|
| 11486 |
sumq2 += w*q*q;
|
|
@@ -11519,6 +11534,9 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11519 |
|
| 11520 |
block_iq1_s * y = vy;
|
| 11521 |
|
|
|
|
|
|
|
|
|
|
| 11522 |
float scales[QK_K/IQ1S_BLOCK_SIZE];
|
| 11523 |
float weight[IQ1S_BLOCK_SIZE];
|
| 11524 |
int8_t L[IQ1S_BLOCK_SIZE];
|
|
@@ -11527,6 +11545,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11527 |
float pairs[2*IQ1S_BLOCK_SIZE];
|
| 11528 |
int * idx = (int *)(pairs + 1);
|
| 11529 |
uint16_t index[IQ1S_BLOCK_SIZE/8];
|
|
|
|
| 11530 |
|
| 11531 |
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 11532 |
|
|
@@ -11572,25 +11591,33 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11572 |
}
|
| 11573 |
}
|
| 11574 |
float best_score = 0, scale = max;
|
| 11575 |
-
int besti1 =
|
| 11576 |
for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
|
| 11577 |
for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
|
| 11578 |
-
float sumqx =
|
| 11579 |
-
float sumq2 =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11580 |
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 11581 |
scale = sumqx/sumq2; best_score = scale*sumqx;
|
| 11582 |
-
besti1 = i1; besti2 = i2;
|
| 11583 |
}
|
| 11584 |
}
|
| 11585 |
}
|
|
|
|
| 11586 |
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
|
| 11587 |
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
|
| 11588 |
for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
|
| 11589 |
if (scale < 0) {
|
| 11590 |
for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
|
| 11591 |
-
scale = -scale;
|
| 11592 |
}
|
| 11593 |
bool all_on_grid = true;
|
|
|
|
| 11594 |
for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
|
| 11595 |
uint16_t u = 0;
|
| 11596 |
for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
|
|
@@ -11598,7 +11625,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11598 |
if (grid_index < 0) {
|
| 11599 |
all_on_grid = false;
|
| 11600 |
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
| 11601 |
-
grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, L + 8*k, NGRID_IQ1S);
|
| 11602 |
GGML_ASSERT(grid_index >= 0);
|
| 11603 |
}
|
| 11604 |
index[k] = grid_index;
|
|
@@ -11609,7 +11636,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11609 |
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
|
| 11610 |
for (int j = 0; j < 8; ++j) {
|
| 11611 |
float w = weight[8*k + j];
|
| 11612 |
-
float q = (pg[j] -
|
| 11613 |
sumqx += w*q*xb[8*k+j];
|
| 11614 |
sumq2 += w*q*q;
|
| 11615 |
}
|
|
@@ -11624,6 +11651,7 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11624 |
y[ibl].qh[ib] = h;
|
| 11625 |
GGML_ASSERT(scale >= 0);
|
| 11626 |
scales[ib] = scale;
|
|
|
|
| 11627 |
max_scale = MAX(max_scale, scale);
|
| 11628 |
}
|
| 11629 |
|
|
@@ -11632,12 +11660,13 @@ static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy
|
|
| 11632 |
continue;
|
| 11633 |
}
|
| 11634 |
|
| 11635 |
-
float d = max_scale/
|
| 11636 |
y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed.
|
| 11637 |
float id = 1/d;
|
| 11638 |
for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
|
| 11639 |
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 11640 |
-
l = MAX(0, MIN(
|
|
|
|
| 11641 |
y[ibl].qh[ib] |= (l << 12);
|
| 11642 |
}
|
| 11643 |
}
|
|
|
|
| 3456 |
const uint16_t * qh = x[i].qh;
|
| 3457 |
|
| 3458 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 3459 |
+
const float dl = d * (2*((qh[ib] >> 12) & 7) + 1);
|
| 3460 |
+
const float delta = qh[ib] & 0x8000 ? -IQ1S_DELTA : IQ1S_DELTA;
|
| 3461 |
for (int l = 0; l < 4; ++l) {
|
| 3462 |
const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
|
| 3463 |
for (int j = 0; j < 8; ++j) {
|
| 3464 |
+
y[j] = dl * (grid[j] + delta);
|
| 3465 |
}
|
| 3466 |
y += 8;
|
| 3467 |
}
|
|
|
|
| 9583 |
const uint8_t * qs = x[i].qs;
|
| 9584 |
const uint16_t * qh = x[i].qh;
|
| 9585 |
|
| 9586 |
+
int sumi1 = 0, sumi2 = 0, sumi3 = 0;
|
| 9587 |
|
| 9588 |
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
| 9589 |
|
|
|
|
| 9602 |
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
|
| 9603 |
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
|
| 9604 |
|
| 9605 |
+
const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
|
| 9606 |
+
const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
|
| 9607 |
+
sumi1 += vaddvq_s32(p1) * ls1;
|
| 9608 |
+
sumi2 += vaddvq_s32(p2) * ls2;
|
| 9609 |
+
sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
|
| 9610 |
+
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
|
| 9611 |
|
| 9612 |
}
|
| 9613 |
|
| 9614 |
+
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
|
| 9615 |
}
|
| 9616 |
|
| 9617 |
*s = sumf;
|
|
|
|
| 9619 |
#elif defined __AVX2__
|
| 9620 |
|
| 9621 |
__m256 accum = _mm256_setzero_ps();
|
| 9622 |
+
float accum1 = 0;
|
| 9623 |
for (int i = 0; i < nb; ++i) {
|
| 9624 |
|
| 9625 |
const int8_t * q8 = y[i].qs;
|
|
|
|
| 9627 |
const uint16_t * qh = x[i].qh;
|
| 9628 |
|
| 9629 |
__m256i sumi = _mm256_setzero_si256();
|
| 9630 |
+
int sumi1 = 0;
|
| 9631 |
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
| 9632 |
const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)],
|
| 9633 |
iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]);
|
|
|
|
| 9639 |
|
| 9640 |
const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1);
|
| 9641 |
const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2);
|
| 9642 |
+
const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
|
| 9643 |
+
const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
|
| 9644 |
+
const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1));
|
| 9645 |
+
const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2));
|
| 9646 |
|
| 9647 |
sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2));
|
| 9648 |
+
sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1
|
| 9649 |
+
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2;
|
| 9650 |
}
|
| 9651 |
|
| 9652 |
+
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
|
| 9653 |
+
accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum);
|
| 9654 |
+
accum1 += d * sumi1;
|
| 9655 |
|
| 9656 |
}
|
| 9657 |
|
| 9658 |
+
*s = hsum_float_8(accum) + IQ1S_DELTA * accum1;
|
| 9659 |
|
| 9660 |
#else
|
| 9661 |
|
|
|
|
| 9666 |
const uint8_t * qs = x[i].qs;
|
| 9667 |
const uint16_t * qh = x[i].qh;
|
| 9668 |
|
| 9669 |
+
int sumi = 0, sumi1 = 0;
|
| 9670 |
for (int ib = 0; ib < QK_K/32; ++ib) {
|
| 9671 |
+
const int ls = 2*((qh[ib] >> 12) & 7) + 1;
|
| 9672 |
+
const int delta = qh[ib] & 0x8000 ? -1 : 1;
|
| 9673 |
int lsum = 0;
|
| 9674 |
for (int l = 0; l < 4; ++l) {
|
| 9675 |
const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8)));
|
|
|
|
| 9678 |
}
|
| 9679 |
q8 += 8;
|
| 9680 |
}
|
| 9681 |
+
sumi += ls * lsum;
|
| 9682 |
+
sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]);
|
| 9683 |
qs += 4;
|
| 9684 |
}
|
| 9685 |
|
| 9686 |
+
sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1);
|
| 9687 |
}
|
| 9688 |
|
| 9689 |
*s = sumf;
|
|
|
|
| 11453 |
}
|
| 11454 |
|
| 11455 |
static int iq1_find_best_neighbour2(const uint16_t * restrict neighbours, const uint64_t * restrict grid,
|
| 11456 |
+
const float * restrict xval, const float * restrict weight, float scale, const float * restrict xg, int8_t * restrict L, int ngrid) {
|
| 11457 |
int num_neighbors = neighbours[0];
|
| 11458 |
GGML_ASSERT(num_neighbors > 0);
|
| 11459 |
float best_score = FLT_MAX;
|
|
|
|
| 11462 |
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 11463 |
float d2 = 0;
|
| 11464 |
for (int i = 0; i < 8; ++i) {
|
| 11465 |
+
float q = xg[(pg[i] - 1)/2];
|
| 11466 |
float w = weight[i];
|
| 11467 |
float diff = scale*q - xval[i];
|
| 11468 |
d2 += w*diff*diff;
|
|
|
|
| 11478 |
float d2 = 0;
|
| 11479 |
for (int j = 0; j < 8; ++j) {
|
| 11480 |
float w = weight[j];
|
| 11481 |
+
float q = xg[(grid_i[j] - 1)/2];
|
| 11482 |
float diff = scale*q - xval[i];
|
| 11483 |
d2 += w*diff*diff;
|
| 11484 |
}
|
|
|
|
| 11495 |
const int8_t * pg = (const int8_t *)(grid + neighbours[j]);
|
| 11496 |
float sumqx = 0, sumq2 = 0;
|
| 11497 |
for (int i = 0; i < 8; ++i) {
|
| 11498 |
+
float q = xg[(pg[i] - 1)/2];
|
| 11499 |
float w = weight[i];
|
| 11500 |
sumqx += w*q*xval[i];
|
| 11501 |
sumq2 += w*q*q;
|
|
|
|
| 11534 |
|
| 11535 |
block_iq1_s * y = vy;
|
| 11536 |
|
| 11537 |
+
const float x_p[3] = {-1 + IQ1S_DELTA, IQ1S_DELTA, 1 + IQ1S_DELTA};
|
| 11538 |
+
const float x_m[3] = {-1 - IQ1S_DELTA, -IQ1S_DELTA, 1 - IQ1S_DELTA};
|
| 11539 |
+
|
| 11540 |
float scales[QK_K/IQ1S_BLOCK_SIZE];
|
| 11541 |
float weight[IQ1S_BLOCK_SIZE];
|
| 11542 |
int8_t L[IQ1S_BLOCK_SIZE];
|
|
|
|
| 11545 |
float pairs[2*IQ1S_BLOCK_SIZE];
|
| 11546 |
int * idx = (int *)(pairs + 1);
|
| 11547 |
uint16_t index[IQ1S_BLOCK_SIZE/8];
|
| 11548 |
+
int8_t shifts[QK_K/IQ1S_BLOCK_SIZE];
|
| 11549 |
|
| 11550 |
for (int ibl = 0; ibl < nbl; ++ibl) {
|
| 11551 |
|
|
|
|
| 11591 |
}
|
| 11592 |
}
|
| 11593 |
float best_score = 0, scale = max;
|
| 11594 |
+
int besti1 = -1, besti2 = -1, best_shift = 0;
|
| 11595 |
for (int i1 = 0; i1 <= IQ1S_BLOCK_SIZE; ++i1) {
|
| 11596 |
for (int i2 = i1; i2 <= IQ1S_BLOCK_SIZE; ++i2) {
|
| 11597 |
+
float sumqx = (sumx[i1] - sumx[0])*x_p[0] + (sumx[i2] - sumx[i1])*x_p[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_p[2];
|
| 11598 |
+
float sumq2 = (sumw[i1] - sumw[0])*x_p[0]*x_p[0] + (sumw[i2] - sumw[i1])*x_p[1]*x_p[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_p[2]*x_p[2];
|
| 11599 |
+
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 11600 |
+
scale = sumqx/sumq2; best_score = scale*sumqx;
|
| 11601 |
+
besti1 = i1; besti2 = i2; best_shift = 1;
|
| 11602 |
+
}
|
| 11603 |
+
sumqx = (sumx[i1] - sumx[0])*x_m[0] + (sumx[i2] - sumx[i1])*x_m[1] + (sumx[IQ1S_BLOCK_SIZE] - sumx[i2])*x_m[2];
|
| 11604 |
+
sumq2 = (sumw[i1] - sumw[0])*x_m[0]*x_m[0] + (sumw[i2] - sumw[i1])*x_m[1]*x_m[1] + (sumw[IQ1S_BLOCK_SIZE] - sumw[i2])*x_m[2]*x_m[2];
|
| 11605 |
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
|
| 11606 |
scale = sumqx/sumq2; best_score = scale*sumqx;
|
| 11607 |
+
besti1 = i1; besti2 = i2; best_shift = -1;
|
| 11608 |
}
|
| 11609 |
}
|
| 11610 |
}
|
| 11611 |
+
GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0);
|
| 11612 |
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
|
| 11613 |
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
|
| 11614 |
for (int j = besti2; j < IQ1S_BLOCK_SIZE; ++j) L[idx[2*j]] = 2;
|
| 11615 |
if (scale < 0) {
|
| 11616 |
for (int j = 0; j < IQ1S_BLOCK_SIZE; ++j) L[j] = 2 - L[j];
|
| 11617 |
+
scale = -scale; best_shift = -best_shift;
|
| 11618 |
}
|
| 11619 |
bool all_on_grid = true;
|
| 11620 |
+
const float * xx = best_shift == 1 ? x_p : x_m;
|
| 11621 |
for (int k = 0; k < IQ1S_BLOCK_SIZE/8; ++k) {
|
| 11622 |
uint16_t u = 0;
|
| 11623 |
for (int j = 0; j < 8; ++j) u |= (L[8*k+j] << 2*j);
|
|
|
|
| 11625 |
if (grid_index < 0) {
|
| 11626 |
all_on_grid = false;
|
| 11627 |
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
|
| 11628 |
+
grid_index = iq1_find_best_neighbour2(neighbours, kgrid_q2xs, xb + 8*k, weight + 8*k, scale, xx, L + 8*k, NGRID_IQ1S);
|
| 11629 |
GGML_ASSERT(grid_index >= 0);
|
| 11630 |
}
|
| 11631 |
index[k] = grid_index;
|
|
|
|
| 11636 |
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
|
| 11637 |
for (int j = 0; j < 8; ++j) {
|
| 11638 |
float w = weight[8*k + j];
|
| 11639 |
+
float q = xx[(pg[j] - 1)/2];
|
| 11640 |
sumqx += w*q*xb[8*k+j];
|
| 11641 |
sumq2 += w*q*q;
|
| 11642 |
}
|
|
|
|
| 11651 |
y[ibl].qh[ib] = h;
|
| 11652 |
GGML_ASSERT(scale >= 0);
|
| 11653 |
scales[ib] = scale;
|
| 11654 |
+
shifts[ib] = best_shift;
|
| 11655 |
max_scale = MAX(max_scale, scale);
|
| 11656 |
}
|
| 11657 |
|
|
|
|
| 11660 |
continue;
|
| 11661 |
}
|
| 11662 |
|
| 11663 |
+
float d = max_scale/15;
|
| 11664 |
y[ibl].d = GGML_FP32_TO_FP16(d*1.125f); // 1.085f is another fudge factor. Don't ask me why it is needed.
|
| 11665 |
float id = 1/d;
|
| 11666 |
for (int ib = 0; ib < QK_K/IQ1S_BLOCK_SIZE; ++ib) {
|
| 11667 |
int l = nearest_int(0.5f*(id*scales[ib]-1));
|
| 11668 |
+
l = MAX(0, MIN(7, l));
|
| 11669 |
+
if (shifts[ib] == -1) l |= 8;
|
| 11670 |
y[ibl].qh[ib] |= (l << 12);
|
| 11671 |
}
|
| 11672 |
}
|