//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARD
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
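/// Converts `shape.any` by forwarding one of its operands. Illustrative
/// sketch (assumed syntax, not verified output):
///   %s = shape.any %s0, %s1 : tensor<?xindex>, tensor<?xindex>
///        -> tensor<?xindex>
/// is replaced by `%s0`.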
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}

namespace {
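/// Converts error-free binary shape ops to their arith counterparts.
/// Illustrative sketch (assumed syntax):
///   %c = shape.add %a, %b : index, index -> index
/// becomes
///   %c = arith.addi %a, %b : index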
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace

namespace {
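/// Converts `shape.broadcast` on extent tensors to a `tensor.generate` over
/// the maximum operand rank, where each output extent is computed by
/// getBroadcastedDim below. Illustrative sketch (assumed syntax):
///   %r = shape.broadcast %a, %b : tensor<?xindex>, tensor<?xindex>
///        -> tensor<?xindex>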
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = lb.create<arith::CmpIOp>(arith::CmpIPredicate::ult,
                                                 outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        lb.create<IfOp>(
              outOfBounds,
              [&](OpBuilder &b, Location loc) {
                b.create<scf::YieldOp>(loc, broadcastedDim);
              },
              [&](OpBuilder &b, Location loc) {
                // The broadcasting logic is:
                // - if one extent (here we arbitrarily choose the extent
                //   from the greater-rank operand) is equal to 1, then take
                //   the extent from the other operand
                // - otherwise, take the extent as-is.
                // Note that this logic remains correct in the presence of
                // dimensions of zero extent.
                Value lesserRankOperandDimension = b.create<arith::SubIOp>(
                    loc, indexTy, outputDimension, rankDiff);
                Value lesserRankOperandExtent = b.create<tensor::ExtractOp>(
                    loc, shape, ValueRange{lesserRankOperandDimension});

                Value dimIsOne =
                    b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                            lesserRankOperandExtent, one);
                Value dim = b.create<arith::SelectOp>(
                    loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
                b.create<scf::YieldOp>(loc, dim);
              })
            .getResult(0);
  }
  return broadcastedDim;
}
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Value replacement = lb.create<tensor::GenerateOp>(
      getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        b.create<tensor::YieldOp>(loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = lb.create<tensor::CastOp>(op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
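/// Converts `shape.const_shape` to a constant extent tensor. Illustrative
/// sketch (assumed syntax): `%s = shape.const_shape [2, 3] : tensor<2xindex>`
/// becomes two `arith.constant` index ops and a `tensor.from_elements`.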
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(
        rewriter.create<arith::ConstantIndexOp>(loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      rewriter.create<tensor::FromElementsOp>(loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}

namespace {
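/// Converts `shape.const_size` to an index constant, e.g. (assumed syntax):
/// `%c = shape.const_size 4` becomes `%c = arith.constant 4 : index`.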
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
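/// Converts `shape.is_broadcastable` on extent tensors to an `scf.for` loop
/// that folds an i1 conjunction over all output dimensions: per dimension,
/// each operand's extent must be out of bounds, equal to 1, or equal to the
/// broadcasted extent. Illustrative sketch (assumed syntax):
///   %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>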
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = lb.create<arith::ConstantIndexOp>(0);
  Value one = lb.create<arith::ConstantIndexOp>(1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return lb.create<tensor::DimOp>(v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = lb.create<arith::MaxUIOp>(v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return lb.create<arith::SubIOp>(indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal =
      rewriter.create<arith::ConstantOp>(loc, i1Ty, rewriter.getBoolAttr(true));

  auto reduceResult = lb.create<ForOp>(
      loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = b.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              b.create<IfOp>(
                   loc, outOfBounds,
                   [&](OpBuilder &b, Location loc) {
                     // Non-existent dimensions are always broadcastable.
                     b.create<scf::YieldOp>(loc, broadcastable);
                   },
                   [&](OpBuilder &b, Location loc) {
                     // Every value needs to be either 1, or the same non-1
                     // value to be broadcastable in this dim.
                     Value operandDimension =
                         b.create<arith::SubIOp>(loc, indexTy, iv, rankDiff);
                     Value dimensionExtent = b.create<tensor::ExtractOp>(
                         loc, shape, ValueRange{operandDimension});

                     Value equalOne = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                     Value equalBroadcasted = b.create<arith::CmpIOp>(
                         loc, arith::CmpIPredicate::eq, dimensionExtent,
                         broadcastedDim);
                     Value result = b.create<arith::AndIOp>(
                         loc, broadcastable,
                         b.create<arith::OrIOp>(loc, equalOne,
                                                equalBroadcasted));
                     b.create<scf::YieldOp>(loc, result);
                   })
                  .getResult(0);
        }

        b.create<scf::YieldOp>(loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}

namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}

namespace {
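/// Converts `shape.get_extent` to a `tensor.extract` from the extent tensor,
/// or directly to `tensor.dim` when the shape comes from `shape.shape_of`.
/// Illustrative sketch (assumed syntax):
///   %e = shape.get_extent %s, %i : tensor<?xindex>, index -> index
/// becomes
///   %e = tensor.extract %s[%i] : tensor<?xindex>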
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}

namespace {
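/// Converts `shape.rank` on an extent tensor to the extent of its only
/// dimension. Illustrative sketch (assumed syntax):
///   %r = shape.rank %s : tensor<?xindex> -> index
/// becomes
///   %c0 = arith.constant 0 : index
///   %r = tensor.dim %s, %c0 : tensor<?xindex>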
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}
namespace {
/// Converts `shape.reduce` to `scf.for`.
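///
/// Illustrative sketch (assumed example, not verified output): computing the
/// number of elements
///
///   %n = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
///     ^bb0(%i : index, %extent : index, %acc : index):
///       %next = arith.muli %acc, %extent : index
///       shape.yield %next : index
///   }
///
/// becomes an `scf.for` from 0 to the rank that extracts each extent and
/// clones the reduction body into the loop, yielding the accumulated values.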
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      rewriter.create<tensor::DimOp>(loc, indexTy, adaptor.getShape(), zero);

  auto loop = rewriter.create<scf::ForOp>(
      loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent = b.create<tensor::ExtractOp>(loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        b.create<scf::YieldOp>(loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
///   %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
///   %c0 = arith.constant 0 : index
///   %0 = dim %arg0, %c0 : tensor<?xindex>
///   %1 = dim %arg1, %c0 : tensor<?xindex>
///   %2 = arith.cmpi "eq", %0, %1 : index
///   %result = scf.if %2 -> (i1) {
///     %c1 = arith.constant 1 : index
///     %true = arith.constant true
///     %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///       %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///       %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///       %7 = arith.cmpi "eq", %5, %6 : index
///       %8 = arith.andi %arg3, %7 : i1
///       scf.yield %8 : i1
///     }
///     scf.yield %4 : i1
///   } else {
///     %false = arith.constant false
///     scf.yield %false : i1
///   }
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      rewriter.create<tensor::DimOp>(loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = rewriter.create<tensor::DimOp>(loc, indexTy, shape, zero);
    Value eqRank = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                  firstRank, rank);
    auto same = rewriter.create<IfOp>(
        loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = b.create<arith::ConstantIndexOp>(loc, 1);
          Value init =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(true));
          auto loop = b.create<scf::ForOp>(
              loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    b.create<tensor::ExtractOp>(loc, firstShape, iv);
                Value rhsExtent = b.create<tensor::ExtractOp>(loc, shape, iv);
                Value eqExtent = b.create<arith::CmpIOp>(
                    loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = b.create<arith::AndIOp>(loc, conj, eqExtent);
                b.create<scf::YieldOp>(loc, ValueRange({conjNext}));
              });
          b.create<scf::YieldOp>(loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              b.create<arith::ConstantOp>(loc, i1Ty, b.getBoolAttr(false));
          b.create<scf::YieldOp>(loc, result);
        });
    result = !result ? same.getResult(0)
                     : rewriter.create<arith::AndIOp>(loc, result,
                                                      same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
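/// Converts `shape.shape_of`. For ranked tensor operands the extents are
/// materialized eagerly with `tensor.from_elements`; for unranked operands a
/// `tensor.generate` loops over the rank. Illustrative sketch (assumed
/// syntax): `%s = shape.shape_of %t : tensor<?x3xf32> -> tensor<2xindex>`
/// becomes a `tensor.dim` for dimension 0, an `arith.constant 3 : index`,
/// and a `tensor.from_elements` of the two.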
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = rewriter.create<tensor::DimOp>(loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = rewriter.create<arith::ConstantIndexOp>(
            loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = rewriter.create<tensor::FromElementsOp>(
        loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = rewriter.create<tensor::RankOp>(loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = b.create<tensor::DimOp>(loc, tensor, dim);
        b.create<tensor::YieldOp>(loc, extent);
      });

  return success();
}

namespace {
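/// Converts `shape.split_at` to a pair of `tensor.extract_slice` ops, with
/// negative split indices counting from the back. Illustrative sketch
/// (assumed semantics): splitting [a, b, c] at index -1 yields head [a, b]
/// and tail [c].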
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = b.create<arith::ConstantIndexOp>(0);
  Value rank = b.create<tensor::DimOp>(adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = b.create<arith::AddIOp>(originalIndex, rank);
  Value indexIsNegative =
      b.create<arith::CmpIOp>(arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = b.create<arith::SelectOp>(indexIsNegative, add, originalIndex);

  Value one = b.create<arith::ConstantIndexOp>(1);
  Value head =
      b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), zero, index, one);
  Value tailSize = b.create<arith::SubIOp>(rank, index);
  Value tail = b.create<tensor::ExtractSliceOp>(adaptor.getOperand(), index,
                                                tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}
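
// Note (illustrative usage, assuming the standard mlir-opt registration of
// this pass): `mlir-opt --convert-shape-to-std input.mlir` exercises the
// patterns above.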

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertShapeToStandardPass() {
  return std::make_unique<ConvertShapeToStandardPass>();
}