MLIR  20.0.0git
DestinationStyleOpInterface.cpp
Go to the documentation of this file.
1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 namespace mlir {
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15 } // namespace mlir
16 
17 namespace {
18 size_t getNumTensorResults(Operation *op) {
19  size_t numTensorResults = 0;
20  for (auto t : op->getResultTypes()) {
21  if (isa<TensorType>(t)) {
22  ++numTensorResults;
23  }
24  }
25  return numTensorResults;
26 }
27 } // namespace
28 
30  DestinationStyleOpInterface dstStyleOp =
31  cast<DestinationStyleOpInterface>(op);
32 
33  SmallVector<OpOperand *> outputTensorOperands;
34  for (OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35  Type type = operand.get().getType();
36  if (isa<TensorType>(type)) {
37  outputTensorOperands.push_back(&operand);
38  } else if (!isa<BaseMemRefType>(type)) {
39  return op->emitOpError("expected that operand #")
40  << operand.getOperandNumber() << " is a tensor or a memref";
41  }
42  }
43 
44  // Verify the number of tensor results matches the number of output tensors.
45  if (getNumTensorResults(op) != outputTensorOperands.size())
46  return op->emitOpError("expected the number of tensor results (")
47  << getNumTensorResults(op)
48  << ") to be equal to the number of output tensors ("
49  << outputTensorOperands.size() << ")";
50 
51  for (OpOperand *opOperand : outputTensorOperands) {
52  OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53  if (result.getType() != opOperand->get().getType())
54  return op->emitOpError("expected type of operand #")
55  << opOperand->getOperandNumber() << " ("
56  << opOperand->get().getType() << ")"
57  << " to match type of corresponding result (" << result.getType()
58  << ")";
59  }
60 
61  return success();
62 }
This class represents an operand of an operation.
Definition: Value.h:267
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier, and an optional mutable component.
Definition: Types.h:74
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult verifyDestinationStyleOpInterface(Operation *op)
Verify that op conforms to the invariants of DestinationStyleOpInterface.
Include the generated interface declarations.