TensorFlow函数 tf.argmax()
参数:
- input:输入数据
- dimension:按某维度查找。
dimension=0:按列查找;
dimension=1:按行查找;
返回:
- 最大值的下标
import tensorflow.compat.v1 as tf tf.disable_v2_behavior() a = tf.constant([1.,2.,5.,0.,4.]) b = tf.constant([[1,2,3],[3,6,1],[4,1,6],[6,2,4]]) # sess = tf.Session() # print(sess.run(tf.argmax(a,0))) with tf.Session() as sess: print(sess.run(tf.argmax(a,0))) with tf.Session() as sess: print(sess.run(tf.argmax(b,1))) with tf.Session() as sess: print(sess.run(tf.argmax(b,0)))
输出内容为:
2 [2 1 2 0] [3 1 2]
解释:
# axis=0时比较每一列的元素,将每一列最大元素所在的索引记录下来,最后输出每一列最大元素所在的索引数组。 # axis=1的时候,将每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组。