Skip to content

Conversation

@hua-zi
Copy link
Owner

@hua-zi hua-zi commented Feb 23, 2023

PR types

Bug fixes

PR changes

APIs

Describe

问题描述:在静态图模式下,输入为FP16类型时,argmin会报TypeError。

import paddle
import numpy as np

paddle.enable_static()

x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmin(x)

exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np},
            fetch_list=[out])

报错:

Traceback (most recent call last):
  File ".\test.py", line 11, in <module>
    out = paddle.argmin(x)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\tensor\search.py", line 271, in argmin
    check_variable_and_dtype(
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 86, in check_variable_and_dtype  
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 147, in check_dtype
    raise TypeError(
TypeError: The data type of 'x' in paddle.argmin must be ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], but received float16.

修复方案:在argmin API静态图模式下的类型检查中增加fp16支持

# PR types
Bug fixes

# PR changes
APIs

# Describe
# 问题描述:在静态图模式下,输入为FP16类型时,argmin会报TypeError。
```
import paddle
import numpy as np

paddle.enable_static()

x_np = np.random.random((10, 16)).astype('float16')
x = paddle.static.data(shape=[10, 16], name='x', dtype='float16')
out = paddle.argmin(x)

exe = paddle.static.Executor()
exe.run(paddle.static.default_startup_program())
out = exe.run(feed={'x': x_np},
            fetch_list=[out])
```
报错:
```
Traceback (most recent call last):
  File ".\test.py", line 11, in <module>
    out = paddle.argmin(x)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\tensor\search.py", line 271, in argmin
    check_variable_and_dtype(
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 86, in check_variable_and_dtype  
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
  File "c:\app\anaconda3\envs\pytorch\lib\site-packages\paddle\fluid\data_feeder.py", line 147, in check_dtype
    raise TypeError(
TypeError: The data type of 'x' in paddle.argmin must be ['float32', 'float64', 'int16', 'int32', 'int64', 'uint8'], but received float16.
```
修复方案:在argmin API静态图模式下的类型检查中增加fp16支持
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant