Conversation
pixelspark
left a comment
Great contribution! Just a few notes; should be fine to merge after those are addressed.
```wgsl
var accumulator = {% if op_type == "ReduceProd" %} {{ scalar_type }}(1) {% else %} Scalar() {% endif %};
var max_element: Scalar = log(Scalar());
```
Hm, perhaps add a note explaining why you initialize to log(Scalar()) (I assume it is a trick to initialize to -Inf so the first value is always higher).
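For example, something along these lines (just a sketch; it assumes `Scalar()` is the zero value of the scalar type, so taking `log` of it yields -Inf):

```wgsl
// Start the running maximum at -Inf (log of the zero value), so the first
// input element compared against it is always larger and becomes the max.
var max_element: Scalar = log(Scalar());
```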
```wgsl
{% elif op_type == "ArgMax" %}
    if(input_val > max_element) {
        max_element = input_val;
        accumulator = f32(count);
```
The output of ArgMax is actually specified as int64 (see here). This should work for now but we might have to fix it later.
Good point. Then it would probably make sense to move ArgMax out of reduce, so that reduce does not have to return mixed data types?
```rust
    &[3, 2],
);

// ONNX test case: do_not_keepdims with ArgMax
```
Can you also enable the corresponding test case in the ONNX backend test (Python scripts)?
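Roughly along these lines, assuming the backend test script follows the usual `onnx.backend.test` pattern (the actual file, backend class, and pattern names in wonnx may differ, so treat this as a sketch):

```python
import onnx.backend.test

# `WonnxBackend` is a placeholder for whatever backend class the script already uses.
backend_test = onnx.backend.test.BackendTest(WonnxBackend, __name__)

# Enable the ArgMax "no_keepdims" node tests alongside the existing reduce tests.
backend_test.include(r"test_argmax_no_keepdims")

# Expose the selected cases to the test runner.
globals().update(backend_test.test_cases)
```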
This PR adds the argmax operator to wonnx.