am17an commited on
Commit
5cca3ec
·
1 Parent(s): 62cf694

CUDA: add conv_2d_dw (llama/14265)

Browse files

* CUDA: add conv_2d_dw

* better naming

* simplify using template

* Review: fix operation ordering in ggml-cuda, use __forceinline__, use more const

ggml/src/ggml-cuda/conv2d-dw.cu ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "conv2d-dw.cuh"
2
+
3
+ struct conv_params {
4
+ int in_w, in_h;
5
+ int out_w, out_h;
6
+ int kernel_w, kernel_h;
7
+ int stride_x, stride_y;
8
+ int padding_x, padding_y;
9
+ int dilation_x, dilation_y;
10
+ int channels, batches;
11
+ };
12
+
13
+ struct kernel_bounds {
14
+ int y_min, y_max;
15
+ int x_min, x_max;
16
+ };
17
+
18
+ __device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
19
+ kernel_bounds bounds;
20
+ bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
21
+ bounds.y_max =
22
+ min(params.kernel_h,
23
+ (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
24
+ bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
25
+ bounds.x_max =
26
+ min(params.kernel_w,
27
+ (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
28
+ return bounds;
29
+ }
30
+
31
+ __device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
32
+ return out_coord * stride + kern_coord * dilation - padding;
33
+ }
34
+
35
+ struct whcn_layout {
36
+ __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
37
+ return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
38
+ }
39
+
40
+ __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
41
+ return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
42
+ }
43
+
44
+ __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
45
+ return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
46
+ y * params.out_w + x;
47
+ }
48
+
49
+ __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
50
+ int & out_x) {
51
+ out_x = global_idx % params.out_w;
52
+ out_y = (global_idx / params.out_w) % params.out_h;
53
+ c = (global_idx / (params.out_w * params.out_h)) % params.channels;
54
+ n = global_idx / (params.out_w * params.out_h * params.channels);
55
+ }
56
+ };
57
+
58
+ struct cwhn_layout {
59
+ __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
60
+ return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
61
+ }
62
+
63
+ __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
64
+ return (ky * params.kernel_w + kx) * params.channels + c;
65
+ }
66
+
67
+ __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
68
+ return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
69
+ x * params.channels + c;
70
+ }
71
+
72
+ __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
73
+ int & out_x) {
74
+ c = global_idx % params.channels;
75
+ out_x = (global_idx / params.channels) % params.out_w;
76
+ out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
77
+ n = global_idx / (params.channels * params.out_w * params.out_h);
78
+ }
79
+ };
80
+
81
+ template <typename T, typename Layout>
82
+ __global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
83
+ const int in_w, const int in_h, const int out_w, const int out_h,
84
+ const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
85
+ const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
86
+ const int channels, const int batches) {
87
+ const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
88
+ const int total_elements = batches * channels * out_h * out_w;
89
+
90
+ if (global_idx >= total_elements) {
91
+ return;
92
+ }
93
+
94
+ conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
95
+ stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
96
+
97
+ int batch_idx, channel_idx, out_y_idx, out_x_idx;
98
+ Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
99
+
100
+ T accumulator = 0;
101
+ kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
102
+
103
+ for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
104
+ int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
105
+
106
+ for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
107
+ int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
108
+
109
+ const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
110
+ const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
111
+
112
+ accumulator += input_val * kernel_val;
113
+ }
114
+ }
115
+
116
+ output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
117
+ }
118
+
119
+ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
120
+ const ggml_tensor * kernel = dst->src[0];
121
+ const ggml_tensor * input = dst->src[1];
122
+
123
+ GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
124
+ const float * w_d = (const float *) kernel->data;
125
+ const float * x_d = (const float *) input->data;
126
+ float * y_d = (float *) dst->data;
127
+
128
+ const int32_t * p = (const int32_t *) dst->op_params;
129
+ const int stride_x = p[0];
130
+ const int stride_y = p[1];
131
+ const int padding_x = p[2];
132
+ const int padding_y = p[3];
133
+ const int dilation_x = p[4];
134
+ const int dilation_y = p[5];
135
+
136
+ const int in_w = input->ne[0];
137
+ const int in_h = input->ne[1];
138
+ const int kernel_w = kernel->ne[0];
139
+ const int kernel_h = kernel->ne[1];
140
+ const int out_w = dst->ne[0];
141
+ const int out_h = dst->ne[1];
142
+ const int channels = dst->ne[2];
143
+ const int batches = dst->ne[3];
144
+
145
+ cudaStream_t st = ctx.stream();
146
+
147
+ const int total = batches * channels * out_h * out_w;
148
+ const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
149
+
150
+ if (ggml_is_contiguous(input)) {
151
+ conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
152
+ x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
153
+ dilation_x, dilation_y, channels, batches);
154
+ } else if (ggml_is_contiguous_channels(input)) {
155
+ conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
156
+ x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
157
+ dilation_x, dilation_y, channels, batches);
158
+ } else {
159
+ GGML_ABORT("Unsupported memory layout for conv_2d_dw");
160
+ }
161
+ }
ggml/src/ggml-cuda/conv2d-dw.cuh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ #pragma once
2
+ #include "common.cuh"
3
+
4
+ #define CUDA_CONV2D_DW_BLOCK_SIZE 256
5
+ void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ggml/src/ggml-cuda/ggml-cuda.cu CHANGED
@@ -11,6 +11,7 @@
11
  #include "ggml-cuda/clamp.cuh"
12
  #include "ggml-cuda/concat.cuh"
13
  #include "ggml-cuda/conv-transpose-1d.cuh"
 
14
  #include "ggml-cuda/convert.cuh"
15
  #include "ggml-cuda/count-equal.cuh"
16
  #include "ggml-cuda/cpy.cuh"
@@ -2310,6 +2311,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2310
  case GGML_OP_IM2COL:
2311
  ggml_cuda_op_im2col(ctx, dst);
2312
  break;
 
 
 
2313
  case GGML_OP_CONV_TRANSPOSE_1D:
2314
  ggml_cuda_op_conv_transpose_1d(ctx,dst);
2315
  break;
@@ -3209,6 +3213,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3209
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
3210
  }
3211
  case GGML_OP_IM2COL:
 
3212
  case GGML_OP_POOL_2D:
3213
  case GGML_OP_SUM:
3214
  case GGML_OP_SUM_ROWS:
 
11
  #include "ggml-cuda/clamp.cuh"
12
  #include "ggml-cuda/concat.cuh"
13
  #include "ggml-cuda/conv-transpose-1d.cuh"
14
+ #include "ggml-cuda/conv2d-dw.cuh"
15
  #include "ggml-cuda/convert.cuh"
16
  #include "ggml-cuda/count-equal.cuh"
17
  #include "ggml-cuda/cpy.cuh"
 
2311
  case GGML_OP_IM2COL:
2312
  ggml_cuda_op_im2col(ctx, dst);
2313
  break;
2314
+ case GGML_OP_CONV_2D_DW:
2315
+ ggml_cuda_op_conv2d_dw(ctx, dst);
2316
+ break;
2317
  case GGML_OP_CONV_TRANSPOSE_1D:
2318
  ggml_cuda_op_conv_transpose_1d(ctx,dst);
2319
  break;
 
3213
  return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
3214
  }
3215
  case GGML_OP_IM2COL:
3216
+ case GGML_OP_CONV_2D_DW:
3217
  case GGML_OP_POOL_2D:
3218
  case GGML_OP_SUM:
3219
  case GGML_OP_SUM_ROWS: