//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ---===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARDPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;
namespace {
/// Converts `shape.any` to its first operand.
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace the op with its first operand; in the error-free case all
  // operands of `shape.any` are interchangeable.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}
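
// Illustrative sketch (hand-written, not verified pass output): an op such as
//
//   %result = shape.any %a, %b : tensor<?xindex>, tensor<?xindex>
//     -> tensor<?xindex>
//
// is erased and all uses of %result simply become uses of %a.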
namespace {
/// Converts binary shape ops to the corresponding arith op on indices.
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace
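
// Illustrative sketch (assuming index-typed, error-free operands, the only
// case this pattern accepts): `shape.add %a, %b : index, index -> index`
// lowers to `arith.addi %a, %b : index`, and `shape.mul` to `arith.muli`.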
namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = arith::ConstantIndexOp::create(lb, 1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
                                              outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        IfOp::create(
            lb, outOfBounds,
            [&](OpBuilder &b, Location loc) {
              // The dimension does not participate for this operand; keep the
              // running value.
              scf::YieldOp::create(b, loc, broadcastedDim);
            },
            [&](OpBuilder &b, Location loc) {
              // If the operand's extent is 1, take the extent computed so
              // far; otherwise take the operand's extent as-is. This remains
              // correct in the presence of zero extents.
              Value lesserRankOperandDimension = arith::SubIOp::create(
                  b, loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = tensor::ExtractOp::create(
                  b, loc, shape, ValueRange{lesserRankOperandDimension});
              Value dimIsOne =
                  arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
                                        lesserRankOperandExtent, one);
              Value dim = arith::SelectOp::create(
                  b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
              scf::YieldOp::create(b, loc, dim);
            })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace
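
// For intuition: given extent tensors [1, 256] and [512, 1], the helper above
// yields 512 for output dimension 0 and 256 for dimension 1. An extent of 1
// defers to the other operand, and dimensions that are out of bounds for a
// lower-ranked operand leave the running value unchanged.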
LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = arith::ConstantIndexOp::create(lb, 0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return tensor::DimOp::create(lb, v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1))
    maxRank = arith::MaxUIOp::create(lb, v, maxRank);

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return arith::SubIOp::create(lb, indexTy, maxRank, v);
                     }));

  Value replacement = tensor::GenerateOp::create(
      lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b),
                              adaptor.getShapes(), rankDiffs, args[0]);
        tensor::YieldOp::create(b, loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = tensor::CastOp::create(lb, op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}
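
// A rough sketch of the IR produced for two extent-tensor operands
// (hand-written and abbreviated, so illustrative only):
//
//   %rank0 = tensor.dim %a, %c0 : tensor<?xindex>
//   %rank1 = tensor.dim %b, %c0 : tensor<?xindex>
//   %maxRank = arith.maxui %rank1, %rank0 : index
//   %result = tensor.generate %maxRank {
//   ^bb0(%dim: index):
//     // per-dimension logic from getBroadcastedDim
//   } : tensor<?xindex>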
namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}
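
// Illustrative sketch: `shape.const_shape [1, 2] : tensor<2xindex>` becomes
// two `arith.constant` index values assembled by `tensor.from_elements` and
// then `tensor.cast` to the result type.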
namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}
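
// E.g. `%s = shape.const_size 42` lowers to `%s = arith.constant 42 : index`.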
namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = arith::ConstantIndexOp::create(lb, 0);
  Value one = arith::ConstantIndexOp::create(lb, 1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return tensor::DimOp::create(lb, v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1))
    maxRank = arith::MaxUIOp::create(lb, v, maxRank);

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return arith::SubIOp::create(lb, indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
                                            rewriter.getBoolAttr(true));

  auto reduceResult = ForOp::create(
      lb, loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. This is the broadcasted dim.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = arith::CmpIOp::create(
              b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              IfOp::create(
                  b, loc, outOfBounds,
                  [&](OpBuilder &b, Location loc) {
                    // Non-participating dimensions are broadcastable.
                    scf::YieldOp::create(b, loc, broadcastable);
                  },
                  [&](OpBuilder &b, Location loc) {
                    // Every extent must be 1 or equal to the broadcasted
                    // dimension.
                    Value operandDimension =
                        arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
                    Value dimensionExtent = tensor::ExtractOp::create(
                        b, loc, shape, ValueRange{operandDimension});

                    Value equalOne = arith::CmpIOp::create(
                        b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                    Value equalBroadcasted =
                        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
                                              dimensionExtent, broadcastedDim);
                    Value result = arith::AndIOp::create(
                        b, loc, broadcastable,
                        arith::OrIOp::create(b, loc, equalOne,
                                             equalBroadcasted));
                    scf::YieldOp::create(b, loc, result);
                  })
                  .getResult(0);
        }

        scf::YieldOp::create(b, loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}
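
// In outline (a hand-written sketch, not verified output), the generated IR
// is a single loop carrying an i1 accumulator:
//
//   %true = arith.constant true
//   %ok = scf.for %i = %c0 to %maxRank step %c1
//       iter_args(%acc = %true) -> (i1) {
//     // %acc AND (each participating extent is 1 or equals the
//     // broadcasted extent for dimension %i)
//   }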
namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower `dim(X, i)` to `get_extent(shape_of(X), i)` and rely on further
  // lowerings.
  auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}
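
// Illustrative (op syntax approximated): `%e = shape.dim %t, %i` is rewritten
// to `shape.get_extent (shape.shape_of %t), %i`, which the patterns below
// then fold into a direct `tensor.dim %t, %i`.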
namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}
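
// Illustrative: when the shape operand is produced by `shape.shape_of %t`,
// `shape.get_extent %shape, %i` becomes `tensor.dim %t, %i`, avoiding
// materialization of the extent tensor; otherwise the extent is read from the
// extent tensor with `tensor.extract`.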
namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}
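
// Illustrative: an extent tensor encodes one extent per element, so
// `%r = shape.rank %shape` on a `tensor<?xindex>` lowers to
// `%r = tensor.dim %shape, %c0`.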
namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);

  auto loop = scf::ForOp::create(
      rewriter, loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        scf::YieldOp::create(b, loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}
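
// Sketch of the lowering (hand-written, with the `shape.reduce` syntax only
// approximated): a product-of-extents reduction such as
//
//   %n = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
//   ^bb0(%i: index, %extent: index, %acc: index):
//     %next = arith.muli %acc, %extent : index
//     shape.yield %next : index
//   }
//
// becomes an `scf.for` over all extents whose body is a clone of the reduce
// body, with the terminator operands forwarded through `scf.yield`.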
namespace {
/// Converts `shape.shape_eq` to rank and extent comparisons. For now, the
/// lowering is only defined on `tensor<?xindex>` operands. The test first
/// compares the ranks and, if they are equal, checks every extent for
/// equality within an `scf.for` loop.
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
    Value eqRank = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
    auto same = IfOp::create(
        rewriter, loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = arith::ConstantIndexOp::create(b, loc, 1);
          Value init =
              arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
          auto loop = scf::ForOp::create(
              b, loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    tensor::ExtractOp::create(b, loc, firstShape, iv);
                Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
                Value eqExtent = arith::CmpIOp::create(
                    b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
                scf::YieldOp::create(b, loc, ValueRange({conjNext}));
              });
          scf::YieldOp::create(b, loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
          scf::YieldOp::create(b, loc, result);
        });
    result = !result ? same.getResult(0)
                     : arith::AndIOp::create(rewriter, loc, result,
                                             same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}
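
// A hand-written sketch of the result for two shapes %A and %B (illustrative
// only):
//
//   %eqRank = arith.cmpi eq, %rankA, %rankB : index
//   %result = scf.if %eqRank -> (i1) {
//     %all = scf.for %i = %c0 to %rankA step %c1
//         iter_args(%conj = %true) -> (i1) {
//       %a = tensor.extract %A[%i] : tensor<?xindex>
//       %b = tensor.extract %B[%i] : tensor<?xindex>
//       %eq = arith.cmpi eq, %a, %b : index
//       %next = arith.andi %conj, %eq : i1
//       scf.yield %next : i1
//     }
//     scf.yield %all : i1
//   } else {
//     scf.yield %false : i1
//   }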
namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = arith::ConstantIndexOp::create(
            rewriter, loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = tensor::FromElementsOp::create(
        rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = tensor::RankOp::create(rewriter, loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = tensor::DimOp::create(b, loc, tensor, dim);
        tensor::YieldOp::create(b, loc, extent);
      });

  return success();
}
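
// Illustrative: for a ranked `tensor<2x?xf32>` %t, this emits
// `tensor.from_elements %c2, %d1` with `%d1 = tensor.dim %t, %c1`; for an
// unranked tensor it instead generates the extent tensor element-wise via
// `tensor.generate` over `tensor.rank %t`.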
namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = arith::ConstantIndexOp::create(b, 0);
  Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = arith::AddIOp::create(b, originalIndex, rank);
  Value indexIsNegative =
      arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);

  Value one = arith::ConstantIndexOp::create(b, 1);
  Value head =
      tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
  Value tailSize = arith::SubIOp::create(b, rank, index);
  Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
                                              tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}
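
// E.g. (sketch): splitting a rank-4 extent tensor %s at index -1 normalizes
// the index to 3 and yields roughly
//
//   %head = tensor.extract_slice %s[0] [3] [1]
//   %tail = tensor.extract_slice %s[3] [1] [1]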
namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace
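
// Illustrative: a ranked input such as `tensor<3xindex>` is simply
// `tensor.cast` to the result type; unranked inputs are rejected with a match
// failure so another pattern (or a later pass) must handle them.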
namespace {
// Import the patterns generated from the declarative rewrite rules.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {
  void runOnOperation() override;
};
} // namespace
void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  populateWithGenerated(patterns);
  patterns.add<AnyOpConversion, BinaryOpConversion<AddOp, arith::AddIOp>,
               BinaryOpConversion<MulOp, arith::MulIOp>, BroadcastOpConverter,
               ConstShapeOpConverter, ConstSizeOpConversion, DimOpConverter,
               IsBroadcastableOpConverter, GetExtentOpConverter,
               RankOpConverter, ReduceOpConverter, ShapeEqOpConverter,
               ShapeOfOpConversion, SplitAtOpConversion,
               ToExtentTensorOpConversion>(patterns.getContext());
}
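
// Usage sketch (the pass argument string comes from the tablegen pass
// definition and is assumed here, not re-verified):
//
//   mlir-opt --convert-shape-to-std input.mlir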