19 #include "llvm/ADT/STLExtras.h"
22 #define GEN_PASS_DEF_CONVERTSHAPETOSTANDARDPASS
23 #include "mlir/Conversion/Passes.h.inc"
37 matchAndRewrite(
AnyOp op, OpAdaptor adaptor,
43 AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
47 rewriter.
replaceOp(op, {adaptor.getInputs().front()});
52 template <
typename SrcOpTy,
typename DstOpTy>
58 matchAndRewrite(SrcOpTy op,
typename SrcOpTy::Adaptor adaptor,
61 if (isa<SizeType>(op.getType()))
76 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
85 Value broadcastedDim = one;
86 for (
auto tup : llvm::zip(extentTensors, rankDiffs)) {
87 Value shape = std::get<0>(tup);
88 Value rankDiff = std::get<1>(tup);
89 Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
90 outputDimension, rankDiff);
96 scf::YieldOp::create(b, loc, broadcastedDim);
106 Value lesserRankOperandDimension = arith::SubIOp::create(
107 b, loc, indexTy, outputDimension, rankDiff);
108 Value lesserRankOperandExtent = tensor::ExtractOp::create(
109 b, loc, shape,
ValueRange{lesserRankOperandDimension});
112 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
113 lesserRankOperandExtent, one);
114 Value dim = arith::SelectOp::create(
115 b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
116 scf::YieldOp::create(b, loc, dim);
120 return broadcastedDim;
124 LogicalResult BroadcastOpConverter::matchAndRewrite(
125 BroadcastOp op, OpAdaptor adaptor,
129 if (isa<ShapeType>(op.getType()))
132 auto loc = op.getLoc();
142 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](
Value v) {
143 return tensor::DimOp::create(lb, v, zero);
147 Value maxRank = ranks.front();
148 for (
Value v : llvm::drop_begin(ranks, 1)) {
149 maxRank = arith::MaxUIOp::create(lb, v, maxRank);
153 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](
Value v) {
154 return arith::SubIOp::create(lb, indexTy, maxRank, v);
157 Value replacement = tensor::GenerateOp::create(
160 Value broadcastedDim =
161 getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
164 tensor::YieldOp::create(b, loc, broadcastedDim);
166 if (replacement.
getType() != op.getType())
167 replacement = tensor::CastOp::create(lb, op.getType(), replacement);
178 matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
183 LogicalResult ConstShapeOpConverter::matchAndRewrite(
184 ConstShapeOp op, OpAdaptor adaptor,
189 if (isa<ShapeType>(op.getType()))
192 auto loc = op.getLoc();
194 for (
auto extent : op.getShape()) {
196 rewriter, loc, extent.getLimitedValue()));
201 tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
212 matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
217 LogicalResult ConstSizeOpConversion::matchAndRewrite(
218 ConstSizeOp op, OpAdaptor adaptor,
221 op, op.getValue().getSExtValue());
226 struct IsBroadcastableOpConverter
231 matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
236 LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
237 IsBroadcastableOp op, OpAdaptor adaptor,
241 if (!llvm::all_of(op.getShapes(),
242 [](
Value v) { return !isa<ShapeType>(v.getType()); }))
245 auto loc = op.getLoc();
255 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](
Value v) {
256 return tensor::DimOp::create(lb, v, zero);
260 Value maxRank = ranks.front();
261 for (
Value v : llvm::drop_begin(ranks, 1)) {
262 maxRank = arith::MaxUIOp::create(lb, v, maxRank);
266 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](
Value v) {
267 return arith::SubIOp::create(lb, indexTy, maxRank, v);
271 Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
274 auto reduceResult = ForOp::create(
275 lb, loc, zero, maxRank, one,
ValueRange{trueVal},
280 Value broadcastedDim = getBroadcastedDim(
283 Value broadcastable = iterArgs[0];
284 for (
auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
285 Value shape, rankDiff;
286 std::tie(shape, rankDiff) = tup;
287 Value outOfBounds = arith::CmpIOp::create(
288 b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
294 scf::YieldOp::create(b, loc, broadcastable);
299 Value operandDimension =
300 arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
301 Value dimensionExtent = tensor::ExtractOp::create(
304 Value equalOne = arith::CmpIOp::create(
305 b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306 Value equalBroadcasted =
307 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
308 dimensionExtent, broadcastedDim);
309 Value result = arith::AndIOp::create(
310 b, loc, broadcastable,
311 arith::OrIOp::create(b, loc, equalOne,
313 scf::YieldOp::create(b, loc, result);
318 scf::YieldOp::create(b, loc, broadcastable);
321 rewriter.
replaceOp(op, reduceResult.getResults().front());
330 matchAndRewrite(DimOp op, OpAdaptor adaptor,
336 DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
341 auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
352 matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
357 LogicalResult GetExtentOpConverter::matchAndRewrite(
358 GetExtentOp op, OpAdaptor adaptor,
361 if (isa<SizeType>(op.getType()))
366 if (
auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
367 if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
386 matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
392 RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
395 if (isa<SizeType>(op.getType()))
409 matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
415 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
418 if (isa<ShapeType>(op.getShape().getType()))
421 auto loc = op.getLoc();
427 tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);
429 auto loop = scf::ForOp::create(
430 rewriter, loc, zero, rank, one, op.getInitVals(),
433 tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);
435 SmallVector<Value, 2> mappedValues{iv, extent};
436 mappedValues.append(args.begin(), args.end());
439 Block *reduceBody = op.getBody();
440 mapping.map(reduceBody->getArguments(), mappedValues);
441 for (
auto &nested : reduceBody->without_terminator())
442 b.
clone(nested, mapping);
445 for (
auto result : reduceBody->getTerminator()->getOperands())
446 mappedResults.push_back(mapping.lookup(result));
447 scf::YieldOp::create(b, loc, mappedResults);
450 rewriter.
replaceOp(op, loop.getResults());
489 matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
495 ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
497 if (!llvm::all_of(op.getShapes(),
498 [](
Value v) { return !isa<ShapeType>(v.getType()); }))
502 if (op.getShapes().size() <= 1) {
508 auto loc = op.getLoc();
511 Value firstShape = adaptor.getShapes().front();
513 tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
514 Value result =
nullptr;
516 for (
Value shape : adaptor.getShapes().drop_front(1)) {
517 Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
518 Value eqRank = arith::CmpIOp::create(
519 rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
520 auto same = IfOp::create(
521 rewriter, loc, eqRank,
525 arith::ConstantOp::create(b, loc, i1Ty, b.
getBoolAttr(
true));
526 auto loop = scf::ForOp::create(
527 b, loc, zero, firstRank, one,
ValueRange{init},
529 Value conj = args[0];
531 tensor::ExtractOp::create(b, loc, firstShape, iv);
532 Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
533 Value eqExtent = arith::CmpIOp::create(
534 b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
535 Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
536 scf::YieldOp::create(b, loc,
ValueRange({conjNext}));
538 scf::YieldOp::create(b, loc, loop.getResults());
542 arith::ConstantOp::create(b, loc, i1Ty, b.
getBoolAttr(
false));
543 scf::YieldOp::create(b, loc, result);
545 result = !result ? same.getResult(0)
546 : arith::AndIOp::create(rewriter, loc, result,
559 matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
564 LogicalResult ShapeOfOpConversion::matchAndRewrite(
565 ShapeOfOp op, OpAdaptor adaptor,
569 if (isa<ShapeType>(op.getType()))
573 auto loc = op.getLoc();
574 Value tensor = adaptor.getArg();
576 if (isa<RankedTensorType>(tensorTy)) {
580 RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
581 int64_t rank = rankedTensorTy.getRank();
582 for (int64_t i = 0; i < rank; i++) {
583 if (rankedTensorTy.isDynamicDim(i)) {
584 Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
585 extentValues.push_back(extent);
588 rewriter, loc, rankedTensorTy.getDimSize(i));
589 extentValues.push_back(extent);
594 Value staticExtentTensor = tensor::FromElementsOp::create(
604 Value rank = tensor::RankOp::create(rewriter, loc, tensor);
608 Value dim = args.front();
609 Value extent = tensor::DimOp::create(b, loc, tensor, dim);
610 tensor::YieldOp::create(b, loc, extent);
622 matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
627 LogicalResult SplitAtOpConversion::matchAndRewrite(
628 SplitAtOp op, OpAdaptor adaptor,
632 if (llvm::any_of(
ValueRange{op.getOperand(), op.getHead(), op.getTail()},
638 Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);
641 Value originalIndex = adaptor.getIndex();
642 Value add = arith::AddIOp::create(b, originalIndex, rank);
643 Value indexIsNegative =
644 arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
645 Value index = arith::SelectOp::create(b, indexIsNegative,
add, originalIndex);
649 tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
650 Value tailSize = arith::SubIOp::create(b, rank, index);
651 Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
658 class ToExtentTensorOpConversion
664 matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
666 if (!isa<RankedTensorType>(adaptor.getInput().getType()))
678 #include "ShapeToStandard.cpp.inc"
683 class ConvertShapeToStandardPass
684 :
public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {
686 void runOnOperation()
override;
690 void ConvertShapeToStandardPass::runOnOperation() {
694 target.addLegalDialect<arith::ArithDialect, SCFDialect,
695 tensor::TensorDialect>();
696 target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
703 auto module = getOperation();
714 BinaryOpConversion<AddOp, arith::AddIOp>,
715 BinaryOpConversion<MulOp, arith::MulIOp>,
716 BroadcastOpConverter,
717 ConstShapeOpConverter,
718 ConstSizeOpConversion,
720 IsBroadcastableOpConverter,
721 GetExtentOpConverter,
727 ToExtentTensorOpConversion>(
patterns.getContext());
static MLIRContext * getContext(OpFoldResult val)
Block represents an ordered list of Operations.
BoolAttr getBoolAttr(bool value)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
This is a utility class for mapping one set of IR entities to another.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location explicitly.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provides a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (the new op replaces the original in place).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateShapeToStandardConversionPatterns(RewritePatternSet &patterns)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
@ AnyOp
No restrictions wrt. which ops are processed.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.