libavfilter/dnn: add more data type support for dnn model input
Currently, only float is supported as the model input data type. Models can take other input data types as well, so this patch adds support for uint8. Signed-off-by: Guo, Yejun <yejun.guo@intel.com> Signed-off-by: Pedro Arthur <bygrandao@gmail.com>
This commit is contained in:
		
							parent
							
								
									25c1cd909f
								
							
						
					
					
						commit
						c636dc9819
					
				| @ -24,8 +24,9 @@ | ||||
|  */ | ||||
| 
 | ||||
| #include "dnn_backend_native.h" | ||||
| #include "libavutil/avassert.h" | ||||
| 
 | ||||
| static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) | ||||
| static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) | ||||
| { | ||||
|     ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; | ||||
|     InputParams *input_params; | ||||
| @ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const | ||||
|         if (input->data){ | ||||
|             av_freep(&input->data); | ||||
|         } | ||||
|         av_assert0(input->dt == DNN_FLOAT); | ||||
|         network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float)); | ||||
|         if (!network->layers[0].output){ | ||||
|             return DNN_ERROR; | ||||
|  | ||||
| @ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) | ||||
|     return graph_buf; | ||||
| } | ||||
| 
 | ||||
| static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) | ||||
| static TF_Tensor *allocate_input_tensor(const DNNInputData *input) | ||||
| { | ||||
|     TF_DataType dt; | ||||
|     size_t size; | ||||
|     int64_t input_dims[] = {1, input->height, input->width, input->channels}; | ||||
|     switch (input->dt) { | ||||
|     case DNN_FLOAT: | ||||
|         dt = TF_FLOAT; | ||||
|         size = sizeof(float); | ||||
|         break; | ||||
|     case DNN_UINT8: | ||||
|         dt = TF_UINT8; | ||||
|         size = sizeof(char); | ||||
|         break; | ||||
|     default: | ||||
|         av_assert0(!"should not reach here"); | ||||
|     } | ||||
| 
 | ||||
|     return TF_AllocateTensor(dt, input_dims, 4, | ||||
|                              input_dims[1] * input_dims[2] * input_dims[3] * size); | ||||
| } | ||||
| 
 | ||||
| static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) | ||||
| { | ||||
|     TFModel *tf_model = (TFModel *)model; | ||||
|     int64_t input_dims[] = {1, input->height, input->width, input->channels}; | ||||
|     TF_SessionOptions *sess_opts; | ||||
|     const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); | ||||
| 
 | ||||
| @ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char | ||||
|     if (tf_model->input_tensor){ | ||||
|         TF_DeleteTensor(tf_model->input_tensor); | ||||
|     } | ||||
|     tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, | ||||
|                                                input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)); | ||||
|     tf_model->input_tensor = allocate_input_tensor(input); | ||||
|     if (!tf_model->input_tensor){ | ||||
|         return DNN_ERROR; | ||||
|     } | ||||
|  | ||||
| @ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; | ||||
| 
 | ||||
| typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; | ||||
| 
 | ||||
| typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType; | ||||
| 
 | ||||
| typedef struct DNNInputData{ | ||||
|     void *data; | ||||
|     DNNDataType dt; | ||||
|     int width, height, channels; | ||||
| } DNNInputData; | ||||
| 
 | ||||
| typedef struct DNNData{ | ||||
|     float *data; | ||||
|     int width, height, channels; | ||||
| @ -42,7 +50,7 @@ typedef struct DNNModel{ | ||||
|     void *model; | ||||
|     // Sets model input and output.
 | ||||
|     // Should be called at least once before model execution.
 | ||||
|     DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output); | ||||
|     DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output); | ||||
| } DNNModel; | ||||
| 
 | ||||
| // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 | ||||
|  | ||||
| @ -40,7 +40,8 @@ typedef struct SRContext { | ||||
|     DNNBackendType backend_type; | ||||
|     DNNModule *dnn_module; | ||||
|     DNNModel *model; | ||||
|     DNNData input, output; | ||||
|     DNNInputData input; | ||||
|     DNNData output; | ||||
|     int scale_factor; | ||||
|     struct SwsContext *sws_contexts[3]; | ||||
|     int sws_slice_h, sws_input_linesize, sws_output_linesize; | ||||
| @ -86,6 +87,7 @@ static av_cold int init(AVFilterContext *context) | ||||
|         return AVERROR(EIO); | ||||
|     } | ||||
| 
 | ||||
|     sr_context->input.dt = DNN_FLOAT; | ||||
|     sr_context->sws_contexts[0] = NULL; | ||||
|     sr_context->sws_contexts[1] = NULL; | ||||
|     sr_context->sws_contexts[2] = NULL; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user