14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
18 size_t getNumTensorResults(
Operation *op) {
19 size_t numTensorResults = 0;
21 if (isa<TensorType>(t)) {
25 return numTensorResults;
30 DestinationStyleOpInterface dstStyleOp =
31 cast<DestinationStyleOpInterface>(op);
34 for (
OpOperand &operand : dstStyleOp.getDpsInitsMutable()) {
35 Type type = operand.get().getType();
36 if (isa<TensorType>(type)) {
37 outputTensorOperands.push_back(&operand);
38 }
else if (!isa<BaseMemRefType>(type)) {
40 << operand.getOperandNumber() <<
" is a tensor or a memref";
45 if (getNumTensorResults(op) != outputTensorOperands.size())
46 return op->
emitOpError(
"expected the number of tensor results (")
47 << getNumTensorResults(op)
48 <<
") to be equal to the number of output tensors ("
49 << outputTensorOperands.size() <<
")";
51 for (
OpOperand *opOperand : outputTensorOperands) {
52 OpResult result = dstStyleOp.getTiedOpResult(opOperand);
53 if (result.
getType() != opOperand->get().getType())
54 return op->
emitOpError(
"expected type of operand #")
55 << opOperand->getOperandNumber() <<
" ("
56 << opOperand->get().getType() <<
")"
57 <<
" to match type of corresponding result (" << result.
getType()
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
LogicalResult verifyDestinationStyleOpInterface(Operation *op)
Verify that op conforms to the invariants of DestinationStyleOpInterface.
Include the generated interface declarations.