MLIR  17.0.0git
DestinationStyleOpInterface.cpp
Go to the documentation of this file.
1 //===- DestinationStyleOpInterface.cpp -- Destination style ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
11 using namespace mlir;
12 
13 namespace mlir {
14 #include "mlir/Interfaces/DestinationStyleOpInterface.cpp.inc"
15 } // namespace mlir
16 
17 OpOperandVector::operator SmallVector<Value>() {
18  SmallVector<Value> result;
19  result.reserve(this->size());
20  llvm::transform(*this, std::back_inserter(result),
21  [](OpOperand *opOperand) { return opOperand->get(); });
22  return result;
23 }
24 
25 namespace {
26 size_t getNumTensorResults(Operation *op) {
27  size_t numTensorResults = 0;
28  for (auto t : op->getResultTypes()) {
29  if (isa<TensorType>(t)) {
30  ++numTensorResults;
31  }
32  }
33  return numTensorResults;
34 }
35 } // namespace
36 
38  DestinationStyleOpInterface dstStyleOp =
39  cast<DestinationStyleOpInterface>(op);
40 
41  SmallVector<OpOperand *> outputTensorOperands;
42  for (OpOperand *operand : dstStyleOp.getDpsInitOperands()) {
43  Type type = operand->get().getType();
44  if (isa<RankedTensorType>(type)) {
45  outputTensorOperands.push_back(operand);
46  } else if (!isa<MemRefType>(type)) {
47  return op->emitOpError("expected that operand #")
48  << operand->getOperandNumber()
49  << " is a ranked tensor or a ranked memref";
50  }
51  }
52 
53  // Verify the number of tensor results matches the number of output tensors.
54  if (getNumTensorResults(op) != outputTensorOperands.size())
55  return op->emitOpError("expected the number of tensor results (")
56  << getNumTensorResults(op)
57  << ") to be equal to the number of output tensors ("
58  << outputTensorOperands.size() << ")";
59 
60  for (OpOperand *opOperand : outputTensorOperands) {
61  OpResult result = dstStyleOp.getTiedOpResult(opOperand);
62  if (result.getType() != opOperand->get().getType())
63  return op->emitOpError("expected type of operand #")
64  << opOperand->getOperandNumber() << " ("
65  << opOperand->get().getType() << ")"
66  << " to match type of corresponding result (" << result.getType()
67  << ")";
68  }
69  return success();
70 }
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:152
This class represents an operand of an operation.
Definition: Value.h:261
This is a value defined by a result of an operation.
Definition: Value.h:448
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_type_range getResultTypes()
Definition: Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:632
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Type getType() const
Return the type of this value.
Definition: Value.h:122
LogicalResult verifyDestinationStyleOpInterface(Operation *op)
Verify that op conforms to the invariants of DestinationStyleOpInterface.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26