MLIR 22.0.0git
ShapeToStandard.cpp
Go to the documentation of this file.
1//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARDPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::shape;
28using namespace mlir::scf;
29
30/// Conversion patterns.
31namespace {
32class AnyOpConversion : public OpConversionPattern<AnyOp> {
33public:
34 using OpConversionPattern<AnyOp>::OpConversionPattern;
35
36 LogicalResult
37 matchAndRewrite(AnyOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override;
39};
40} // namespace
41
42LogicalResult
43AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
44 ConversionPatternRewriter &rewriter) const {
45 // Replace `any` with its first operand.
46 // Any operand would be a valid substitution.
47 rewriter.replaceOp(op, {adaptor.getInputs().front()});
48 return success();
49}
50
51namespace {
52template <typename SrcOpTy, typename DstOpTy>
53class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
54public:
55 using OpConversionPattern<SrcOpTy>::OpConversionPattern;
56
57 LogicalResult
58 matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
59 ConversionPatternRewriter &rewriter) const override {
60 // For now, only error-free types are supported by this lowering.
61 if (isa<SizeType>(op.getType()))
62 return failure();
63
64 rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
65 adaptor.getRhs());
66 return success();
67 }
68};
69} // namespace
70
71namespace {
72struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
73 using OpConversionPattern<BroadcastOp>::OpConversionPattern;
74
75 LogicalResult
76 matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override;
78};
79
80// Get the resulting extent in a given dimension. This is computed with any
81// number of extent tensors and shifted offsets into them.
82Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
83 ValueRange rankDiffs, Value outputDimension) {
84 Value one = arith::ConstantIndexOp::create(lb, 1);
85 Value broadcastedDim = one;
86 for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
87 Value shape = std::get<0>(tup);
88 Value rankDiff = std::get<1>(tup);
89 Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
90 outputDimension, rankDiff);
91 Type indexTy = lb.getIndexType();
92 broadcastedDim =
93 IfOp::create(
94 lb, outOfBounds,
95 [&](OpBuilder &b, Location loc) {
96 scf::YieldOp::create(b, loc, broadcastedDim);
97 },
98 [&](OpBuilder &b, Location loc) {
99 // The broadcasting logic is:
100 // - if one extent (here we arbitrarily choose the
101 // extent from the greater-rank operand) is equal to 1,
102 // then take the extent from the other operand
103 // - otherwise, take the extent as-is.
104 // Note that this logic remains correct in the presence
105 // of dimensions of zero extent.
106 Value lesserRankOperandDimension = arith::SubIOp::create(
107 b, loc, indexTy, outputDimension, rankDiff);
108 Value lesserRankOperandExtent = tensor::ExtractOp::create(
109 b, loc, shape, ValueRange{lesserRankOperandDimension});
110
111 Value dimIsOne =
112 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
113 lesserRankOperandExtent, one);
114 Value dim = arith::SelectOp::create(
115 b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
116 scf::YieldOp::create(b, loc, dim);
117 })
118 .getResult(0);
119 }
120 return broadcastedDim;
121}
122} // namespace
123
124LogicalResult BroadcastOpConverter::matchAndRewrite(
125 BroadcastOp op, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter) const {
127 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
128 // on shapes.
129 if (isa<ShapeType>(op.getType()))
130 return failure();
131
132 auto loc = op.getLoc();
133 ImplicitLocOpBuilder lb(loc, rewriter);
134
135 Value zero = arith::ConstantIndexOp::create(lb, 0);
136 Type indexTy = lb.getIndexType();
137
138 // Save all the ranks for bounds checking. Because this is a tensor
139 // representing the shape extents, the rank is the extent of the only
140 // dimension in the tensor.
141 SmallVector<Value> ranks, rankDiffs;
142 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
143 return tensor::DimOp::create(lb, v, zero);
144 }));
145
146 // Find the maximum rank
147 Value maxRank = ranks.front();
148 for (Value v : llvm::drop_begin(ranks, 1)) {
149 maxRank = arith::MaxUIOp::create(lb, v, maxRank);
150 }
151
152 // Calculate the difference of ranks and the maximum rank for later offsets.
153 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
154 return arith::SubIOp::create(lb, indexTy, maxRank, v);
155 }));
156
157 Value replacement = tensor::GenerateOp::create(
158 lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
159 [&](OpBuilder &b, Location loc, ValueRange args) {
160 Value broadcastedDim =
161 getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
162 rankDiffs, args[0]);
163
164 tensor::YieldOp::create(b, loc, broadcastedDim);
165 });
166 if (replacement.getType() != op.getType())
167 replacement = tensor::CastOp::create(lb, op.getType(), replacement);
168 rewriter.replaceOp(op, replacement);
169 return success();
170}
171
172namespace {
173class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
174public:
175 using OpConversionPattern<ConstShapeOp>::OpConversionPattern;
176
177 LogicalResult
178 matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override;
180};
181} // namespace
182
183LogicalResult ConstShapeOpConverter::matchAndRewrite(
184 ConstShapeOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const {
186
187 // For now, this lowering supports only extent tensors, not `shape.shape`
188 // types.
189 if (isa<ShapeType>(op.getType()))
190 return failure();
191
192 auto loc = op.getLoc();
193 SmallVector<Value, 4> extentOperands;
194 for (auto extent : op.getShape()) {
195 extentOperands.push_back(arith::ConstantIndexOp::create(
196 rewriter, loc, extent.getLimitedValue()));
197 }
198 Type resultTy =
199 RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
200 Value tensor =
201 tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
202 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
203 return success();
204}
205
206namespace {
207class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
208public:
209 using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
210
211 LogicalResult
212 matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter) const override;
214};
215} // namespace
216
217LogicalResult ConstSizeOpConversion::matchAndRewrite(
218 ConstSizeOp op, OpAdaptor adaptor,
219 ConversionPatternRewriter &rewriter) const {
220 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
221 op, op.getValue().getSExtValue());
222 return success();
223}
224
225namespace {
226struct IsBroadcastableOpConverter
227 : public OpConversionPattern<IsBroadcastableOp> {
228 using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;
229
230 LogicalResult
231 matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
232 ConversionPatternRewriter &rewriter) const override;
233};
234} // namespace
235
236LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
237 IsBroadcastableOp op, OpAdaptor adaptor,
238 ConversionPatternRewriter &rewriter) const {
239 // For now, this lowering is only defined on `tensor<?xindex>` operands, not
240 // on shapes.
241 if (!llvm::all_of(op.getShapes(),
242 [](Value v) { return !isa<ShapeType>(v.getType()); }))
243 return failure();
244
245 auto loc = op.getLoc();
246 ImplicitLocOpBuilder lb(loc, rewriter);
247 Value zero = arith::ConstantIndexOp::create(lb, 0);
248 Value one = arith::ConstantIndexOp::create(lb, 1);
249 Type indexTy = lb.getIndexType();
250
251 // Save all the ranks for bounds checking. Because this is a tensor
252 // representing the shape extents, the rank is the extent of the only
253 // dimension in the tensor.
254 SmallVector<Value> ranks, rankDiffs;
255 llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
256 return tensor::DimOp::create(lb, v, zero);
257 }));
258
259 // Find the maximum rank
260 Value maxRank = ranks.front();
261 for (Value v : llvm::drop_begin(ranks, 1)) {
262 maxRank = arith::MaxUIOp::create(lb, v, maxRank);
263 }
264
265 // Calculate the difference of ranks and the maximum rank for later offsets.
266 llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
267 return arith::SubIOp::create(lb, indexTy, maxRank, v);
268 }));
269
270 Type i1Ty = rewriter.getI1Type();
271 Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
272 rewriter.getBoolAttr(true));
273
274 auto reduceResult = ForOp::create(
275 lb, loc, zero, maxRank, one, ValueRange{trueVal},
276 [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
277 // Find a non-1 dim, if it exists. Note that the first part of this
278 // could reuse the Broadcast lowering entirely, but we redo the work
279 // here to make optimizations easier between the two loops.
280 Value broadcastedDim = getBroadcastedDim(
281 ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);
282
283 Value broadcastable = iterArgs[0];
284 for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
285 Value shape, rankDiff;
286 std::tie(shape, rankDiff) = tup;
287 Value outOfBounds = arith::CmpIOp::create(
288 b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
289 broadcastable =
290 IfOp::create(
291 b, loc, outOfBounds,
292 [&](OpBuilder &b, Location loc) {
293 // Non existent dimensions are always broadcastable
294 scf::YieldOp::create(b, loc, broadcastable);
295 },
296 [&](OpBuilder &b, Location loc) {
297 // Every value needs to be either 1, or the same non-1
298 // value to be broadcastable in this dim.
299 Value operandDimension =
300 arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
301 Value dimensionExtent = tensor::ExtractOp::create(
302 b, loc, shape, ValueRange{operandDimension});
303
304 Value equalOne = arith::CmpIOp::create(
305 b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
306 Value equalBroadcasted =
307 arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
308 dimensionExtent, broadcastedDim);
309 Value result = arith::AndIOp::create(
310 b, loc, broadcastable,
311 arith::OrIOp::create(b, loc, equalOne,
312 equalBroadcasted));
313 scf::YieldOp::create(b, loc, result);
314 })
315 .getResult(0);
316 }
317
318 scf::YieldOp::create(b, loc, broadcastable);
319 });
320
321 rewriter.replaceOp(op, reduceResult.getResults().front());
322 return success();
323}
324
325namespace {
326class DimOpConverter : public OpConversionPattern<DimOp> {
327 using OpConversionPattern<DimOp>::OpConversionPattern;
328
329 LogicalResult
330 matchAndRewrite(DimOp op, OpAdaptor adaptor,
331 ConversionPatternRewriter &rewriter) const override;
332};
333} // namespace
334
335LogicalResult
336DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
337 ConversionPatternRewriter &rewriter) const {
338 // Lower to dim(X, i) to get_extent(shape_of(X), i) and rely on further
339 // lowerings. This can be further optimized if needed to avoid intermediate
340 // steps.
341 auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
342 rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
343 op.getIndex());
344 return success();
345}
346
347namespace {
348class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
349 using OpConversionPattern<GetExtentOp>::OpConversionPattern;
350
351 LogicalResult
352 matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
353 ConversionPatternRewriter &rewriter) const override;
354};
355} // namespace
356
357LogicalResult GetExtentOpConverter::matchAndRewrite(
358 GetExtentOp op, OpAdaptor adaptor,
359 ConversionPatternRewriter &rewriter) const {
360 // For now, only error-free types are supported by this lowering.
361 if (isa<SizeType>(op.getType()))
362 return failure();
363
364 // Derive shape extent directly from shape origin if possible. This
365 // circumvents the necessity to materialize the shape in memory.
366 if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
367 if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
368 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
369 adaptor.getDim());
370 return success();
371 }
372 }
373
374 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
375 adaptor.getShape(),
376 ValueRange{adaptor.getDim()});
377 return success();
378}
379
380namespace {
381class RankOpConverter : public OpConversionPattern<shape::RankOp> {
382public:
383 using OpConversionPattern<shape::RankOp>::OpConversionPattern;
384
385 LogicalResult
386 matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
387 ConversionPatternRewriter &rewriter) const override;
388};
389} // namespace
390
391LogicalResult
392RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
393 ConversionPatternRewriter &rewriter) const {
394 // For now, this lowering supports only error-free types.
395 if (isa<SizeType>(op.getType()))
396 return failure();
397
398 rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
399 return success();
400}
401
402namespace {
403/// Converts `shape.reduce` to `scf.for`.
404struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
405public:
406 using OpConversionPattern::OpConversionPattern;
407
408 LogicalResult
409 matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
410 ConversionPatternRewriter &rewriter) const final;
411};
412} // namespace
413
414LogicalResult
415ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
416 ConversionPatternRewriter &rewriter) const {
417 // For now, this lowering is only defined on `tensor<?xindex>` operands.
418 if (isa<ShapeType>(op.getShape().getType()))
419 return failure();
420
421 auto loc = op.getLoc();
422
423 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
424 Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
425 Type indexTy = rewriter.getIndexType();
426 Value rank =
427 tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);
428
429 auto loop = scf::ForOp::create(
430 rewriter, loc, zero, rank, one, op.getInitVals(),
431 [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
432 Value extent =
433 tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);
434
435 SmallVector<Value, 2> mappedValues{iv, extent};
436 mappedValues.append(args.begin(), args.end());
437
438 IRMapping mapping;
439 Block *reduceBody = op.getBody();
440 mapping.map(reduceBody->getArguments(), mappedValues);
441 for (auto &nested : reduceBody->without_terminator())
442 b.clone(nested, mapping);
443
444 SmallVector<Value, 2> mappedResults;
445 for (auto result : reduceBody->getTerminator()->getOperands())
446 mappedResults.push_back(mapping.lookup(result));
447 scf::YieldOp::create(b, loc, mappedResults);
448 });
449
450 rewriter.replaceOp(op, loop.getResults());
451 return success();
452}
453
454namespace {
455/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
456/// only defined on `tensor<?xindex>` operands. The test for equality first
457/// compares their size and, if equal, checks every extent for equality.
458///
459/// Example:
460///
461/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
462///
463/// becomes
464///
465/// %c0 = arith.constant 0 : index
466/// %0 = dim %arg0, %c0 : tensor<?xindex>
467/// %1 = dim %arg1, %c0 : tensor<?xindex>
468/// %2 = arith.cmpi "eq", %0, %1 : index
469/// %result = scf.if %2 -> (i1) {
470/// %c1 = arith.constant 1 : index
471/// %true = arith.constant true
472/// %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
473/// %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
474/// %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
475/// %7 = arith.cmpi "eq", %5, %6 : index
476/// %8 = arith.andi %arg3, %7 : i1
477/// scf.yield %8 : i1
478/// }
479/// scf.yield %4 : i1
480/// } else {
481/// %false = arith.constant false
482/// scf.yield %false : i1
483/// }
484///
485struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
486 using OpConversionPattern<ShapeEqOp>::OpConversionPattern;
487
488 LogicalResult
489 matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
490 ConversionPatternRewriter &rewriter) const override;
491};
492} // namespace
493
494LogicalResult
495ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
496 ConversionPatternRewriter &rewriter) const {
497 if (!llvm::all_of(op.getShapes(),
498 [](Value v) { return !isa<ShapeType>(v.getType()); }))
499 return failure();
500
501 Type i1Ty = rewriter.getI1Type();
502 if (op.getShapes().size() <= 1) {
503 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
504 rewriter.getBoolAttr(true));
505 return success();
506 }
507
508 auto loc = op.getLoc();
509 Type indexTy = rewriter.getIndexType();
510 Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
511 Value firstShape = adaptor.getShapes().front();
512 Value firstRank =
513 tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
514 Value result = nullptr;
515 // Generate a linear sequence of compares, all with firstShape as lhs.
516 for (Value shape : adaptor.getShapes().drop_front(1)) {
517 Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
518 Value eqRank = arith::CmpIOp::create(
519 rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
520 auto same = IfOp::create(
521 rewriter, loc, eqRank,
522 [&](OpBuilder &b, Location loc) {
523 Value one = arith::ConstantIndexOp::create(b, loc, 1);
524 Value init =
525 arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
526 auto loop = scf::ForOp::create(
527 b, loc, zero, firstRank, one, ValueRange{init},
528 [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
529 Value conj = args[0];
530 Value lhsExtent =
531 tensor::ExtractOp::create(b, loc, firstShape, iv);
532 Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
533 Value eqExtent = arith::CmpIOp::create(
534 b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
535 Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
536 scf::YieldOp::create(b, loc, ValueRange({conjNext}));
537 });
538 scf::YieldOp::create(b, loc, loop.getResults());
539 },
540 [&](OpBuilder &b, Location loc) {
541 Value result =
542 arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
543 scf::YieldOp::create(b, loc, result);
544 });
545 result = !result ? same.getResult(0)
546 : arith::AndIOp::create(rewriter, loc, result,
547 same.getResult(0));
548 }
549 rewriter.replaceOp(op, result);
550 return success();
551}
552
553namespace {
554class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
555public:
556 using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
557
558 LogicalResult
559 matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
560 ConversionPatternRewriter &rewriter) const override;
561};
562} // namespace
563
564LogicalResult ShapeOfOpConversion::matchAndRewrite(
565 ShapeOfOp op, OpAdaptor adaptor,
566 ConversionPatternRewriter &rewriter) const {
567
568 // For now, only error-free types are supported by this lowering.
569 if (isa<ShapeType>(op.getType()))
570 return failure();
571
572 // For ranked tensor arguments, lower to `tensor.from_elements`.
573 auto loc = op.getLoc();
574 Value tensor = adaptor.getArg();
575 Type tensorTy = tensor.getType();
576 if (isa<RankedTensorType>(tensorTy)) {
577
578 // Build values for individual extents.
579 SmallVector<Value, 8> extentValues;
580 RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
581 int64_t rank = rankedTensorTy.getRank();
582 for (int64_t i = 0; i < rank; i++) {
583 if (rankedTensorTy.isDynamicDim(i)) {
584 Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
585 extentValues.push_back(extent);
586 } else {
587 Value extent = arith::ConstantIndexOp::create(
588 rewriter, loc, rankedTensorTy.getDimSize(i));
589 extentValues.push_back(extent);
590 }
591 }
592
593 // Materialize extent tensor.
594 Value staticExtentTensor = tensor::FromElementsOp::create(
595 rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
596 extentValues);
597 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
598 staticExtentTensor);
599 return success();
600 }
601
602 // Lower to `tensor.generate` otherwise.
603 auto *ctx = rewriter.getContext();
604 Value rank = tensor::RankOp::create(rewriter, loc, tensor);
605 rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
606 op, getExtentTensorType(ctx), ValueRange{rank},
607 [&](OpBuilder &b, Location loc, ValueRange args) {
608 Value dim = args.front();
609 Value extent = tensor::DimOp::create(b, loc, tensor, dim);
610 tensor::YieldOp::create(b, loc, extent);
611 });
612
613 return success();
614}
615
616namespace {
617class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
618public:
619 using OpConversionPattern<SplitAtOp>::OpConversionPattern;
620
621 LogicalResult
622 matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
623 ConversionPatternRewriter &rewriter) const override;
624};
625} // namespace
626
627LogicalResult SplitAtOpConversion::matchAndRewrite(
628 SplitAtOp op, OpAdaptor adaptor,
629 ConversionPatternRewriter &rewriter) const {
630 // Error conditions are not implemented, only lower if all operands and
631 // results are extent tensors.
632 if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
633 [](Value v) { return isa<ShapeType>(v.getType()); }))
634 return failure();
635
636 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
637 Value zero = arith::ConstantIndexOp::create(b, 0);
638 Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);
639
640 // index < 0 ? index + rank : index
641 Value originalIndex = adaptor.getIndex();
642 Value add = arith::AddIOp::create(b, originalIndex, rank);
643 Value indexIsNegative =
644 arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
645 Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);
646
647 Value one = arith::ConstantIndexOp::create(b, 1);
648 Value head =
649 tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
650 Value tailSize = arith::SubIOp::create(b, rank, index);
651 Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
652 tailSize, one);
653 rewriter.replaceOp(op, {head, tail});
654 return success();
655}
656
657namespace {
658class ToExtentTensorOpConversion
659 : public OpConversionPattern<ToExtentTensorOp> {
660public:
661 using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
662
663 LogicalResult
664 matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
665 ConversionPatternRewriter &rewriter) const override {
666 if (!isa<RankedTensorType>(adaptor.getInput().getType()))
667 return rewriter.notifyMatchFailure(op, "input needs to be a tensor");
668
669 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
670 adaptor.getInput());
671 return success();
672 }
673};
674} // namespace
675
676namespace {
677/// Import the Shape Ops to Std Patterns.
678#include "ShapeToStandard.cpp.inc"
679} // namespace
680
681namespace {
682/// Conversion pass.
683class ConvertShapeToStandardPass
684 : public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {
685
686 void runOnOperation() override;
687};
688} // namespace
689
690void ConvertShapeToStandardPass::runOnOperation() {
691 // Setup target legality.
692 MLIRContext &ctx = getContext();
693 ConversionTarget target(ctx);
694 target.addLegalDialect<arith::ArithDialect, SCFDialect,
695 tensor::TensorDialect>();
696 target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();
697
698 // Setup conversion patterns.
699 RewritePatternSet patterns(&ctx);
701
702 // Apply conversion.
703 auto module = getOperation();
704 if (failed(applyPartialConversion(module, target, std::move(patterns))))
705 signalPassFailure();
706}
707
710 // clang-format off
711 populateWithGenerated(patterns);
712 patterns.add<
713 AnyOpConversion,
714 BinaryOpConversion<AddOp, arith::AddIOp>,
715 BinaryOpConversion<MulOp, arith::MulIOp>,
716 BroadcastOpConverter,
717 ConstShapeOpConverter,
718 ConstSizeOpConversion,
719 DimOpConverter,
720 IsBroadcastableOpConverter,
721 GetExtentOpConverter,
722 RankOpConverter,
723 ReduceOpConverter,
724 ShapeEqOpConverter,
725 ShapeOfOpConversion,
726 SplitAtOpConversion,
727 ToExtentTensorOpConversion>(patterns.getContext());
728 // clang-format on
729}
for(Operation *op :ops)
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
#define add(a, b)
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void populateShapeToStandardConversionPatterns(RewritePatternSet &patterns)
@ AnyOp
No restrictions wrt. which ops are processed.