#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace the op with its first input. Any operand would be a valid
  // substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}
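// Example (sketch): once the operands are extent tensors,
//   %s = shape.any %a, %b : tensor<?xindex>
// simply forwards %a.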
namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace
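// Example (sketch): instantiated below as BinaryOpConversion<AddOp,
// arith::AddIOp>,
//   %r = shape.add %a, %b : index, index -> index
// becomes
//   %r = arith.addi %a, %b : index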
namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                // Out-of-bounds dimensions of a lesser-rank operand do not
                // constrain the broadcasted extent.
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // If this operand's extent is 1, keep the extent broadcasted
                // so far; otherwise take this operand's extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});
                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
                              adaptor.getShapes(), rankDiffs, args.front());
        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}
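// Example (sketch): `%s = shape.broadcast %a, %b` on extent tensors becomes a
// tensor.generate of extent max(rank(%a), rank(%b)) whose body computes each
// output dimension with getBroadcastedDim above.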
namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({static_cast<int64_t>(extentOperands.size())},
                            rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}
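// Example (sketch):
//   %s = shape.const_shape [1, 2] : tensor<2xindex>
// becomes roughly:
//   %c1 = arith.constant 1 : index
//   %c2 = arith.constant 2 : index
//   %s = tensor.from_elements %c1, %c2 : tensor<2xindex>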
namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}
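// Example (sketch): `%s = shape.const_size 42` becomes
// `%s = arith.constant 42 : index`.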
namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // The extent this dimension broadcasts to, if broadcastable.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every extent needs to be either 1 or the same non-1
                     // value to be broadcastable in this dimension.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});
                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}
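// Example (sketch): `%ok = shape.is_broadcastable %a, %b` becomes an scf.for
// over the maximum rank that and-accumulates, per dimension, the predicate
// "extent is 1 or equals the broadcasted extent".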
namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower `dim(x, i)` to `get_extent(shape_of(x), i)` and rely on the
  // lowerings of those two ops.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}
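// Example (sketch): `%d = shape.dim %t, %i` is rewritten to
//   %s = shape.shape_of %t
//   %d = shape.get_extent %s, %i
// and the patterns below lower the rest.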
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
public:
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive the shape extent directly from the shape's origin if possible.
  // This circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}
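// Example (sketch): if %s is produced by `shape.shape_of %t`, then
// `%e = shape.get_extent %s, %i` folds to `%e = tensor.dim %t, %i`;
// otherwise it becomes `%e = tensor.extract %s[%i] : tensor<?xindex>`.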
namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}
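// Example (sketch): `%r = shape.rank %s : tensor<?xindex> -> index` becomes
//   %c0 = arith.constant 0 : index
//   %r = tensor.dim %s, %c0 : tensor<?xindex>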
namespace {
class ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        // Clone the reduce body into the loop, with its block arguments
        // mapped to the induction variable, the current extent, and the
        // iteration arguments.
        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}
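// Example (sketch): a `shape.reduce` that multiplies all extents becomes an
// scf.for from 0 to rank(%s) that extracts each extent and runs the cloned
// reduction body, carrying the accumulator as an iteration argument.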
namespace {
class ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
public:
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}
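// Example (sketch): for two shapes, `shape.shape_eq %a, %b` becomes an scf.if
// on rank equality whose then-branch and-reduces per-extent equality in an
// scf.for, and whose else-branch yields false.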
namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize the extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}
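// Example (sketch): for %t : tensor<2x?xf32>,
//   %s = shape.shape_of %t : tensor<2x?xf32> -> tensor<2xindex>
// becomes roughly:
//   %c2 = arith.constant 2 : index
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %t, %c1 : tensor<2x?xf32>
//   %s = tensor.from_elements %c2, %d1 : tensor<2xindex>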
namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}
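// Example (sketch): `shape.split_at %s, %i` yields head = %s[0..i) and
// tail = %s[i..rank) as tensor.extract_slice ops, with a negative %i
// interpreted as rank + %i.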
namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace
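// Example (sketch): once the input is already an extent tensor,
// `shape.to_extent_tensor` reduces to a `tensor.cast` to the result type.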
#include "ShapeToStandard.cpp.inc"
namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace
void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply the conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}