21 #include "llvm/Support/InterleavedRange.h"
23 #define DEBUG_TYPE "reify-result-shapes"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
28 #define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
29 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
42 ReifyRankedShapedTypeOpInterface op) {
43 LLVM_DEBUG({
DBGS() <<
" reifying op: " << op <<
"\n"; });
47 reifiedResultShapes.empty()) {
48 return op->emitWarning() <<
"failed to get the reified shapes";
51 bool modified =
false;
54 for (
const auto &[oldTy, reifiedShape] :
55 llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
57 if (!isa<RankedTensorType, MemRefType>(oldTy)) {
58 outTypes.push_back(oldTy);
62 ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
65 for (
auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
68 if (!maybeCst.has_value()) {
69 dim = ShapedType::kDynamic;
77 if (shape == shapedTy.getShape()) {
78 outTypes.push_back(oldTy);
82 outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
87 LLVM_DEBUG({
DBGS() <<
"- op doesn't require update\n"; });
92 DBGS() <<
"- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
94 DBGS() <<
"- outTypes: " << llvm::interleaved_array(outTypes) <<
" \n";
103 assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
106 for (
auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
108 Type oldTy = oldRes.getType();
110 if (oldTy == reifiedTy || !isa<MemRefType, RankedTensorType>(oldTy)) {
111 newResults.push_back(newRes);
117 if (isa<RankedTensorType>(reifiedTy)) {
118 newResults.push_back(
119 tensor::CastOp::create(rewriter, loc, oldTy, newRes));
121 assert(isa<MemRefType>(reifiedTy) &&
"expected a memref type");
122 newResults.push_back(
123 memref::CastOp::create(rewriter, loc, oldTy, newRes));
128 DBGS() <<
"- reified results " << llvm::interleaved_array(newResults)
140 struct ReifyResultShapesPass final
141 :
public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
142 void runOnOperation()
override;
146 void ReifyResultShapesPass::runOnOperation() {
148 getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
151 if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
156 for (ReifyRankedShapedTypeOpInterface op : ops) {
157 rewriter.setInsertionPoint(op);
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult reifyOpResultShapes(RewriterBase &rewriter, ReifyRankedShapedTypeOpInterface op)
Reifies the results of op, potentially replacing op with a reified version.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Include the generated interface declarations.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).