import argparse

import numpy as np
from edgetpu.basic.basic_engine import BasicEngine
class ClassificationEngine(BasicEngine):
    """Convenience wrapper around ``BasicEngine`` that feeds a dataset to a model.

    The class adds dataset/label loading and a generator that runs inference
    over the loaded dataset one sample at a time. Everything else is
    delegated to ``BasicEngine``.
    """

    def __init__(self, model_path, device_path=None):
        """Initialise the underlying engine.

        Args:
            model_path: Forwarded unchanged to ``BasicEngine.__init__``.
            device_path: Forwarded unchanged to ``BasicEngine.__init__``;
                defaults to ``None``.
        """
        super().__init__(model_path, device_path)

    def load_data(self, dataset_path):
        """Load the dataset with ``np.load`` and store it on ``self.dataset``.

        Args:
            dataset_path: Path (or file object) accepted by ``np.load``.
        """
        self.dataset = np.load(dataset_path)

    def load_labels(self, labels_path):
        """Load the labels with ``np.load`` and store them on ``self.labels``.

        Args:
            labels_path: Path (or file object) accepted by ``np.load``.
        """
        self.labels = np.load(labels_path)

    def get_all_inputs(self):
        """Return the result of ``self.get_all_input_tensors()``.

        NOTE(review): ``get_all_input_tensors`` is not defined in this class;
        it is presumably expected from ``BasicEngine`` -- confirm it exists.
        """
        return self.get_all_input_tensors()

    def get_all_outputs(self):
        """Return the result of ``self.get_all_output_tensors()``.

        NOTE(review): ``get_all_output_tensors`` is not defined in this class;
        it is presumably expected from ``BasicEngine`` -- confirm it exists.
        """
        return self.get_all_output_tensors()

    def inference_step(self):
        """Lazily run inference over every sample in ``self.dataset``.

        ``load_data`` must have been called first so that ``self.dataset``
        exists. Each sample is reshaped to the size reported by
        ``required_input_array_size()`` and cast to ``uint8`` (no rescaling
        is applied) before being handed to ``run_inference``.

        Yields:
            Whatever ``run_inference`` returns for each sample, in dataset
            order.
        """
        # The required input size is a property of the model, not of the
        # sample, so query it once instead of on every iteration.
        input_size = self.required_input_array_size()
        for sample in self.dataset:
            model_input = np.reshape(sample, input_size).astype(np.uint8)
            yield self.run_inference(model_input)

    def get_inference_time(self):
        """Return the inference time reported by ``BasicEngine``.

        Delegates through ``super()``: calling the method on ``self`` here
        would re-enter this override and recurse until ``RecursionError``.
        """
        return super().get_inference_time()

    def get_loss(self, label, output, loss_fn):
        """Return ``loss_fn(label, output)``.

        Args:
            label: Expected value, passed as the first argument to ``loss_fn``.
            output: Model output, passed as the second argument to ``loss_fn``.
            loss_fn: Callable taking ``(label, output)``.
        """
        return loss_fn(label, output)

    def required_input_array_size(self):
        """Return the required input array size reported by ``BasicEngine``.

        Delegates through ``super()``: calling the method on ``self`` here
        would re-enter this override and recurse until ``RecursionError``.
        """
        return super().required_input_array_size()
def main():
    """Parse command-line arguments and run inference on the first sample.

    Builds a ``ClassificationEngine`` from ``--model``, loads the dataset
    given by ``--input`` and runs a single inference step.

    Returns:
        The value produced by the first inference step.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '-m', '--model', required=True, help='File path of .tflite file.')
    # NOTE: --labels is accepted but not consumed by this script yet.
    parser.add_argument(
        '-l', '--labels', help='File path of labels file.')
    parser.add_argument(
        '-i', '--input', required=True, help='File path of input data file')
    args = parser.parse_args()

    engine = ClassificationEngine(args.model)
    engine.load_data(args.input)

    inference_step = engine.inference_step()
    # An empty dataset would otherwise surface as a bare StopIteration;
    # report it as a usage error instead.
    output = next(inference_step, None)
    if output is None:
        parser.error('input data file contains no samples')
    return output


if __name__ == '__main__':
    main()