Files
tyapi-server/internal/infrastructure/external/ocr/baidu_ocr_service.go
2025-09-12 01:15:09 +08:00

515 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ocr
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"go.uber.org/zap"
"tyapi-server/internal/application/certification/dto/responses"
)
// BaiduOCRService is a client for the Baidu OCR HTTP API.
type BaiduOCRService struct {
	// apiKey and secretKey are the credentials exchanged for an access token;
	// endpoint is the base URL every API path is appended to.
	apiKey, secretKey, endpoint string
	// timeout is the HTTP client timeout applied by sendRequest.
	timeout time.Duration
	// logger receives progress and result messages.
	logger *zap.Logger
}
// NewBaiduOCRService returns a BaiduOCRService that uses the given credentials
// and logger, the public Baidu endpoint, and a 30-second request timeout.
func NewBaiduOCRService(apiKey, secretKey string, logger *zap.Logger) *BaiduOCRService {
	svc := new(BaiduOCRService)
	svc.apiKey = apiKey
	svc.secretKey = secretKey
	svc.endpoint = "https://aip.baidubce.com"
	svc.timeout = 30 * time.Second
	svc.logger = logger
	return svc
}
// RecognizeBusinessLicense runs business-license OCR on imageBytes.
//
// It obtains an access token, POSTs the image (base64, then URL-encoded) as a
// form body to /rest/2.0/ocr/v1/business_license and converts the JSON reply
// into a BusinessLicenseResult.
//
// An error is returned when the token or OCR request fails, the reply is not
// valid JSON, the reply carries a non-zero error_code, or the reply has no
// words_result object.
func (s *BaiduOCRService) RecognizeBusinessLicense(ctx context.Context, imageBytes []byte) (*responses.BusinessLicenseResult, error) {
	s.logger.Info("开始识别营业执照", zap.Int("image_size", len(imageBytes)))

	accessToken, err := s.getAccessToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("获取访问令牌失败: %w", err)
	}

	// The URL carries only the access token; it is escaped so an unexpected
	// character cannot corrupt the query string.
	apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/business_license?access_token=%s", s.endpoint, url.QueryEscape(accessToken))

	// Standard base64 output contains '+', '/' and '=', so it must be
	// URL-encoded before being placed in the form body.
	imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
	payload := strings.NewReader("image=" + url.QueryEscape(imageBase64))

	resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
	if err != nil {
		return nil, fmt.Errorf("营业执照识别请求失败: %w", err)
	}

	var result map[string]interface{}
	if err := json.Unmarshal(resp, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	// A non-zero error_code marks an API-level failure. error_msg is read with
	// a checked assertion so a missing or non-string field cannot panic.
	if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
		errorMsg, ok := result["error_msg"].(string)
		if !ok {
			errorMsg = fmt.Sprintf("error_code=%.0f", errCode)
		}
		return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
	}

	// Reject replies without a words_result object instead of handing the
	// parser a shape it does not expect.
	if _, ok := result["words_result"].(map[string]interface{}); !ok {
		return nil, fmt.Errorf("OCR响应缺少words_result字段")
	}

	licenseResult := s.parseBusinessLicenseResult(result)
	s.logger.Info("营业执照识别成功",
		zap.String("company_name", licenseResult.CompanyName),
		zap.String("legal_representative", licenseResult.LegalPersonName),
		zap.String("registered_capital", licenseResult.RegisteredCapital),
	)
	return licenseResult, nil
}
// RecognizeIDCard runs ID-card OCR on imageBytes.
//
// side selects the face of the card; "front" is parsed as the personal-data
// side and any other value as the reverse side. The image is POSTed as a
// URL-encoded form to /rest/2.0/ocr/v1/idcard.
//
// An error is returned when the token or OCR request fails, the reply is not
// valid JSON, the reply carries a non-zero error_code, or the reply has no
// words_result object.
func (s *BaiduOCRService) RecognizeIDCard(ctx context.Context, imageBytes []byte, side string) (*responses.IDCardResult, error) {
	s.logger.Info("开始识别身份证", zap.String("side", side), zap.Int("image_size", len(imageBytes)))

	accessToken, err := s.getAccessToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("获取访问令牌失败: %w", err)
	}

	// The URL carries only the access token; it is escaped so an unexpected
	// character cannot corrupt the query string.
	apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/idcard?access_token=%s", s.endpoint, url.QueryEscape(accessToken))

	// url.Values escapes both fields, so a caller-supplied side value cannot
	// inject extra form parameters. The Baidu idcard endpoint documents the
	// card-face parameter as "id_card_side".
	form := url.Values{}
	form.Set("image", base64.StdEncoding.EncodeToString(imageBytes))
	form.Set("id_card_side", side)
	payload := strings.NewReader(form.Encode())

	resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
	if err != nil {
		return nil, fmt.Errorf("身份证识别请求失败: %w", err)
	}

	var result map[string]interface{}
	if err := json.Unmarshal(resp, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	// A non-zero error_code marks an API-level failure. error_msg is read with
	// a checked assertion so a missing or non-string field cannot panic.
	if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
		errorMsg, ok := result["error_msg"].(string)
		if !ok {
			errorMsg = fmt.Sprintf("error_code=%.0f", errCode)
		}
		return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
	}

	// Reject replies without a words_result object instead of handing the
	// parser a shape it does not expect.
	if _, ok := result["words_result"].(map[string]interface{}); !ok {
		return nil, fmt.Errorf("OCR响应缺少words_result字段")
	}

	idCardResult := s.parseIDCardResult(result, side)
	// NOTE(review): this writes the holder's name and full ID number to the
	// log at Info level — confirm that is acceptable for this deployment.
	s.logger.Info("身份证识别成功",
		zap.String("name", idCardResult.Name),
		zap.String("id_number", idCardResult.IDCardNumber),
		zap.String("side", side),
	)
	return idCardResult, nil
}
// RecognizeGeneralText runs general-purpose text OCR on imageBytes.
//
// The image is POSTed (base64, then URL-encoded) as a form body to
// /rest/2.0/ocr/v1/general_basic and the reply is converted into a
// GeneralTextResult.
//
// An error is returned when the token or OCR request fails, the reply is not
// valid JSON, the reply carries a non-zero error_code, or the reply has no
// words_result array.
func (s *BaiduOCRService) RecognizeGeneralText(ctx context.Context, imageBytes []byte) (*responses.GeneralTextResult, error) {
	s.logger.Info("开始通用文字识别", zap.Int("image_size", len(imageBytes)))

	accessToken, err := s.getAccessToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("获取访问令牌失败: %w", err)
	}

	// The URL carries only the access token; it is escaped so an unexpected
	// character cannot corrupt the query string.
	apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/general_basic?access_token=%s", s.endpoint, url.QueryEscape(accessToken))

	// Standard base64 output contains '+', '/' and '=', so it must be
	// URL-encoded before being placed in the form body.
	imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
	payload := strings.NewReader("image=" + url.QueryEscape(imageBase64))

	resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
	if err != nil {
		return nil, fmt.Errorf("通用文字识别请求失败: %w", err)
	}

	var result map[string]interface{}
	if err := json.Unmarshal(resp, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	// A non-zero error_code marks an API-level failure. error_msg is read with
	// a checked assertion so a missing or non-string field cannot panic.
	if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
		errorMsg, ok := result["error_msg"].(string)
		if !ok {
			errorMsg = fmt.Sprintf("error_code=%.0f", errCode)
		}
		return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
	}

	// For this endpoint words_result is an array of lines; reject anything
	// else instead of handing the parser a shape it does not expect.
	if _, ok := result["words_result"].([]interface{}); !ok {
		return nil, fmt.Errorf("OCR响应缺少words_result字段")
	}

	textResult := s.parseGeneralTextResult(result)
	s.logger.Info("通用文字识别成功",
		zap.Int("word_count", len(textResult.Words)),
		zap.Float64("confidence", textResult.Confidence),
	)
	return textResult, nil
}
// RecognizeFromURL downloads the image at imageURL and runs the OCR routine
// selected by ocrType on it.
//
// Supported ocrType values are "business_license", "idcard_front",
// "idcard_back" and "general_text"; any other value yields an error. On
// success the returned value is the pointer result of the matching Recognize*
// method. On failure the returned value is an untyped nil, so callers may
// safely compare it against nil.
func (s *BaiduOCRService) RecognizeFromURL(ctx context.Context, imageURL string, ocrType string) (interface{}, error) {
	s.logger.Info("从URL识别图片", zap.String("url", imageURL), zap.String("type", ocrType))

	imageBytes, err := s.downloadImage(ctx, imageURL)
	if err != nil {
		s.logger.Error("下载图片失败", zap.Error(err))
		return nil, fmt.Errorf("下载图片失败: %w", err)
	}

	// Each branch checks the error before returning: passing a nil *T straight
	// through an interface{} return would produce a non-nil interface value.
	switch ocrType {
	case "business_license":
		res, err := s.RecognizeBusinessLicense(ctx, imageBytes)
		if err != nil {
			return nil, err
		}
		return res, nil
	case "idcard_front":
		res, err := s.RecognizeIDCard(ctx, imageBytes, "front")
		if err != nil {
			return nil, err
		}
		return res, nil
	case "idcard_back":
		res, err := s.RecognizeIDCard(ctx, imageBytes, "back")
		if err != nil {
			return nil, err
		}
		return res, nil
	case "general_text":
		res, err := s.RecognizeGeneralText(ctx, imageBytes)
		if err != nil {
			return nil, err
		}
		return res, nil
	default:
		return nil, fmt.Errorf("不支持的OCR类型: %s", ocrType)
	}
}
// getAccessToken requests an access token from /oauth/2.0/token using the
// client_credentials grant with the service's apiKey and secretKey.
//
// It returns an error when the request fails, the reply is not valid JSON,
// the reply contains a non-empty "error" field, or no non-empty access_token
// is present.
//
// NOTE(review): a new token is requested on every call; nothing in this
// function caches it.
func (s *BaiduOCRService) getAccessToken(ctx context.Context) (string, error) {
	// Build the query with url.Values so credentials containing reserved
	// characters are escaped rather than corrupting the URL.
	query := url.Values{}
	query.Set("grant_type", "client_credentials")
	query.Set("client_id", s.apiKey)
	query.Set("client_secret", s.secretKey)
	tokenURL := s.endpoint + "/oauth/2.0/token?" + query.Encode()

	resp, err := s.sendRequest(ctx, "POST", tokenURL, nil)
	if err != nil {
		return "", fmt.Errorf("获取访问令牌请求失败: %w", err)
	}

	var result map[string]interface{}
	if err := json.Unmarshal(resp, &result); err != nil {
		return "", fmt.Errorf("解析访问令牌响应失败: %w", err)
	}

	// error_description is read with a checked assertion so a reply that
	// carries "error" without a description cannot panic.
	if errCode, ok := result["error"].(string); ok && errCode != "" {
		errorDesc, _ := result["error_description"].(string)
		return "", fmt.Errorf("获取访问令牌失败: %s - %s", errCode, errorDesc)
	}

	accessToken, ok := result["access_token"].(string)
	if !ok || accessToken == "" {
		return "", fmt.Errorf("响应中未找到访问令牌")
	}
	return accessToken, nil
}
// sendRequest issues an HTTP request with the given method, URL and body and
// returns the raw response body.
//
// The request is bound to ctx, sent with a form-urlencoded Content-Type and a
// fixed User-Agent, and limited by s.timeout. Any status other than 200 OK is
// reported as an error.
func (s *BaiduOCRService) sendRequest(ctx context.Context, method, url string, body io.Reader) ([]byte, error) {
	req, err := http.NewRequestWithContext(ctx, method, url, body)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("User-Agent", "tyapi-server/1.0")
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	httpClient := http.Client{Timeout: s.timeout}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("请求失败,状态码: %d", status)
	}

	data, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, fmt.Errorf("读取响应内容失败: %w", readErr)
	}
	return data, nil
}
// parseBusinessLicenseResult converts a decoded business-license OCR reply
// into a BusinessLicenseResult.
//
// Each field is taken from words_result[<label>]["words"]. A field whose
// entry is missing or has an unexpected type is left empty instead of
// causing a panic. ProcessedAt is set to the current time.
func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface{}) *responses.BusinessLicenseResult {
	// A missing or mistyped words_result leaves wordsResult nil; lookups on a
	// nil map are safe and simply miss.
	wordsResult, _ := result["words_result"].(map[string]interface{})

	// field returns the recognised text stored under the given label, or "".
	field := func(label string) string {
		entry, ok := wordsResult[label].(map[string]interface{})
		if !ok {
			return ""
		}
		text, _ := entry["words"].(string)
		return text
	}

	// The confidence is a fixed placeholder; it is not read from the reply.
	const confidence = 0.9

	return &responses.BusinessLicenseResult{
		CompanyName:       field("单位名称"),
		UnifiedSocialCode: field("社会信用代码"),
		LegalPersonName:   field("法人"),
		LegalPersonID:     "", // a business license does not carry the legal person's ID number
		RegisteredCapital: field("注册资本"),
		Address:           field("地址"),
		Confidence:        confidence,
		ProcessedAt:       time.Now(),
	}
}
// parseIDCardResult converts a decoded ID-card OCR reply into an IDCardResult.
//
// When side is "front" the name, gender, nation, birthday, address and ID
// number are filled in; for any other side the issuing agency and validity
// period are filled in. A field whose entry is missing or has an unexpected
// type is left empty instead of causing a panic.
func (s *BaiduOCRService) parseIDCardResult(result map[string]interface{}, side string) *responses.IDCardResult {
	// A missing or mistyped words_result leaves wordsResult nil; lookups on a
	// nil map are safe and simply miss.
	wordsResult, _ := result["words_result"].(map[string]interface{})

	// field returns the recognised text stored under the given label, or "".
	field := func(label string) string {
		entry, ok := wordsResult[label].(map[string]interface{})
		if !ok {
			return ""
		}
		text, _ := entry["words"].(string)
		return text
	}

	idCardResult := &responses.IDCardResult{
		Side:       side,
		Confidence: s.extractConfidence(result),
	}

	if side == "front" {
		idCardResult.Name = field("姓名")
		idCardResult.Gender = field("性别")
		idCardResult.Nation = field("民族")
		idCardResult.Birthday = field("出生")
		idCardResult.Address = field("住址")
		idCardResult.IDCardNumber = field("公民身份号码")
		return idCardResult
	}

	idCardResult.IssuingAgency = field("签发机关")
	idCardResult.ValidPeriod = field("有效期限")
	return idCardResult
}
// parseGeneralTextResult converts a decoded general-text OCR reply into a
// GeneralTextResult.
//
// words_result is expected to be an array of objects with a string "words"
// field. A missing or mistyped words_result yields an empty Words slice, and
// entries without a string "words" field are skipped, instead of causing a
// panic. Every line is given a confidence of 1.0.
func (s *BaiduOCRService) parseGeneralTextResult(result map[string]interface{}) *responses.GeneralTextResult {
	wordsResult, _ := result["words_result"].([]interface{})

	textResult := &responses.GeneralTextResult{
		Confidence: s.extractConfidence(result),
		Words:      make([]responses.TextLine, 0, len(wordsResult)),
	}

	for _, item := range wordsResult {
		entry, ok := item.(map[string]interface{})
		if !ok {
			continue
		}
		text, ok := entry["words"].(string)
		if !ok {
			continue
		}
		textResult.Words = append(textResult.Words, responses.TextLine{
			Text:       text,
			Confidence: 1.0, // no per-line confidence is read from the reply
		})
	}
	return textResult
}
// extractConfidence returns the top-level "confidence" number of a decoded
// OCR reply, or 0 when the key is absent or not a number.
//
// NOTE(review): confirm the endpoints used in this file actually return a
// top-level "confidence" field; if they do not, this always yields 0.
func (s *BaiduOCRService) extractConfidence(result map[string]interface{}) float64 {
	value, isNumber := result["confidence"].(float64)
	if !isNumber {
		return 0.0
	}
	return value
}
// extractWords collects every string "words" value found under words_result.
//
// words_result may be an object of labelled entries (structured documents)
// or an array of entries (general text); entries that are not objects, or
// whose "words" value is not a string, are ignored. The returned slice is
// never nil. For the object form the order follows map iteration order.
func (s *BaiduOCRService) extractWords(result map[string]interface{}) []string {
	collected := make([]string, 0)

	// appendText adds item's "words" string to collected when item has one.
	appendText := func(item interface{}) {
		entry, isMap := item.(map[string]interface{})
		if !isMap {
			return
		}
		if text, isString := entry["words"].(string); isString {
			collected = append(collected, text)
		}
	}

	raw, present := result["words_result"]
	if !present {
		return collected
	}

	switch typed := raw.(type) {
	case map[string]interface{}:
		for _, item := range typed {
			appendText(item)
		}
	case []interface{}:
		for _, item := range typed {
			appendText(item)
		}
	}
	return collected
}
// downloadImage fetches imageURL with a GET request bound to ctx and returns
// the response body.
//
// It returns an error when the request cannot be built or sent, the status is
// not 200 OK, the body cannot be read, or the body is larger than
// maxImageBytes.
//
// NOTE(review): imageURL is fetched as given; if it can come from untrusted
// callers, consider restricting schemes and hosts before calling this.
func (s *BaiduOCRService) downloadImage(ctx context.Context, imageURL string) ([]byte, error) {
	// maxImageBytes bounds how much of the response is buffered so that an
	// oversized or endless body cannot exhaust memory.
	const maxImageBytes = 10 << 20 // 10 MiB

	client := &http.Client{
		Timeout: 30 * time.Second,
	}

	req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("下载图片失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
	}

	// Read one byte past the limit so an oversized body can be told apart
	// from one that is exactly maxImageBytes long.
	imageBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxImageBytes+1))
	if err != nil {
		return nil, fmt.Errorf("读取图片内容失败: %w", err)
	}
	if len(imageBytes) > maxImageBytes {
		return nil, fmt.Errorf("图片大小超过限制: %d 字节", maxImageBytes)
	}
	return imageBytes, nil
}
// ValidateBusinessLicense checks that a business-license result is usable.
//
// It returns an error when the confidence is below 0.8 or when the company
// name, legal person name or unified social credit code is empty, checked in
// that order; otherwise it returns nil.
func (s *BaiduOCRService) ValidateBusinessLicense(result *responses.BusinessLicenseResult) error {
	switch {
	case result.Confidence < 0.8:
		return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
	case result.CompanyName == "":
		return fmt.Errorf("未能识别公司名称")
	case result.LegalPersonName == "":
		return fmt.Errorf("未能识别法定代表人")
	case result.UnifiedSocialCode == "":
		return fmt.Errorf("未能识别统一社会信用代码")
	}
	return nil
}
// ValidateIDCard checks that an ID-card result is usable.
//
// It returns an error when the confidence is below 0.8. For the "front" side
// the name and ID number must be non-empty; for any other side the issuing
// agency and validity period must be non-empty. Otherwise it returns nil.
func (s *BaiduOCRService) ValidateIDCard(result *responses.IDCardResult) error {
	isFront := result.Side == "front"

	switch {
	case result.Confidence < 0.8:
		return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
	case isFront && result.Name == "":
		return fmt.Errorf("未能识别姓名")
	case isFront && result.IDCardNumber == "":
		return fmt.Errorf("未能识别身份证号码")
	case !isFront && result.IssuingAgency == "":
		return fmt.Errorf("未能识别签发机关")
	case !isFront && result.ValidPeriod == "":
		return fmt.Errorf("未能识别有效期限")
	}
	return nil
}