ShapeToStandard.cpp
//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTSHAPETOSTANDARDPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::shape;
using namespace mlir::scf;

/// Conversion patterns.
namespace {
class AnyOpConversion : public OpConversionPattern<AnyOp> {
public:
  using OpConversionPattern<AnyOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Replace `any` with its first operand.
  // Any operand would be a valid substitution.
  rewriter.replaceOp(op, {adaptor.getInputs().front()});
  return success();
}
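
// Illustrative example of the rewrite above (schematic IR, names invented
// for exposition): an op such as
//
//   %result = shape.any %a, %b : tensor<?xindex>, tensor<?xindex>
//
// is simply replaced by %a; `shape.any` is free to return any of its
// operands, so the first one is as valid as any other.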

namespace {
template <typename SrcOpTy, typename DstOpTy>
class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
public:
  using OpConversionPattern<SrcOpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // For now, only error-free types are supported by this lowering.
    if (isa<SizeType>(op.getType()))
      return failure();

    rewriter.replaceOpWithNewOp<DstOpTy>(op, adaptor.getLhs(),
                                         adaptor.getRhs());
    return success();
  }
};
} // namespace
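
// For illustration, an error-free `shape.add` on `index` operands maps 1:1
// onto the arith dialect (schematic IR, for exposition only):
//
//   %sum = shape.add %a, %b : index, index -> index
//
// becomes
//
//   %sum = arith.addi %a, %b : index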

namespace {
struct BroadcastOpConverter : public OpConversionPattern<BroadcastOp> {
  using OpConversionPattern<BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

// Get the resulting extent in a given dimension. This is computed with any
// number of extent tensors and shifted offsets into them.
Value getBroadcastedDim(ImplicitLocOpBuilder lb, ValueRange extentTensors,
                        ValueRange rankDiffs, Value outputDimension) {
  Value one = arith::ConstantIndexOp::create(lb, 1);
  Value broadcastedDim = one;
  for (auto tup : llvm::zip(extentTensors, rankDiffs)) {
    Value shape = std::get<0>(tup);
    Value rankDiff = std::get<1>(tup);
    Value outOfBounds = arith::CmpIOp::create(lb, arith::CmpIPredicate::ult,
                                              outputDimension, rankDiff);
    Type indexTy = lb.getIndexType();
    broadcastedDim =
        IfOp::create(
            lb, outOfBounds,
            [&](OpBuilder &b, Location loc) {
              scf::YieldOp::create(b, loc, broadcastedDim);
            },
            [&](OpBuilder &b, Location loc) {
              // The broadcasting logic is:
              // - if one extent (here we arbitrarily choose the extent from
              //   the greater-rank operand) is equal to 1, then take the
              //   extent from the other operand
              // - otherwise, take the extent as-is.
              // Note that this logic remains correct in the presence of
              // dimensions of zero extent.
              Value lesserRankOperandDimension = arith::SubIOp::create(
                  b, loc, indexTy, outputDimension, rankDiff);
              Value lesserRankOperandExtent = tensor::ExtractOp::create(
                  b, loc, shape, ValueRange{lesserRankOperandDimension});

              Value dimIsOne =
                  arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
                                        lesserRankOperandExtent, one);
              Value dim = arith::SelectOp::create(
                  b, loc, dimIsOne, broadcastedDim, lesserRankOperandExtent);
              scf::YieldOp::create(b, loc, dim);
            })
            .getResult(0);
  }
  return broadcastedDim;
}
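
// As a worked example of the logic above: broadcasting the extent tensors
// [5, 1, 7] (rank 3) and [3, 1] (rank 2) gives rankDiffs = [0, 1], and:
//   - output dim 0: the second operand is out of bounds (0 < 1), so the
//     first operand's extent 5 is taken as-is;
//   - output dim 1: the first operand contributes extent 1, which is
//     overridden by the second operand's extent 3;
//   - output dim 2: the first operand contributes extent 7; the second
//     operand's extent 1 keeps the running value.
// The resulting extent tensor is [5, 3, 7].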
} // namespace

LogicalResult BroadcastOpConverter::matchAndRewrite(
    BroadcastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);

  Value zero = arith::ConstantIndexOp::create(lb, 0);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return tensor::DimOp::create(lb, v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = arith::MaxUIOp::create(lb, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return arith::SubIOp::create(lb, indexTy, maxRank, v);
                     }));

  Value replacement = tensor::GenerateOp::create(
      lb, getExtentTensorType(lb.getContext()), ValueRange{maxRank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value broadcastedDim =
            getBroadcastedDim(ImplicitLocOpBuilder(loc, b), adaptor.getShapes(),
                              rankDiffs, args[0]);

        tensor::YieldOp::create(b, loc, broadcastedDim);
      });
  if (replacement.getType() != op.getType())
    replacement = tensor::CastOp::create(lb, op.getType(), replacement);
  rewriter.replaceOp(op, replacement);
  return success();
}

namespace {
class ConstShapeOpConverter : public OpConversionPattern<ConstShapeOp> {
public:
  using OpConversionPattern<ConstShapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstShapeOpConverter::matchAndRewrite(
    ConstShapeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, this lowering supports only extent tensors, not `shape.shape`
  // types.
  if (isa<ShapeType>(op.getType()))
    return failure();

  auto loc = op.getLoc();
  SmallVector<Value, 4> extentOperands;
  for (auto extent : op.getShape()) {
    extentOperands.push_back(arith::ConstantIndexOp::create(
        rewriter, loc, extent.getLimitedValue()));
  }
  Type resultTy =
      RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType());
  Value tensor =
      tensor::FromElementsOp::create(rewriter, loc, resultTy, extentOperands);
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, tensor);
  return success();
}
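
// For illustration, the rewrite above turns (schematic IR)
//
//   %shape = shape.const_shape [2, 3] : tensor<2xindex>
//
// into
//
//   %c2 = arith.constant 2 : index
//   %c3 = arith.constant 3 : index
//   %0 = tensor.from_elements %c2, %c3 : tensor<2xindex>
//
// followed by a (here trivial) tensor.cast to the computed result type.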

namespace {
class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
public:
  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ConstSizeOpConversion::matchAndRewrite(
    ConstSizeOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(
      op, op.getValue().getSExtValue());
  return success();
}

namespace {
struct IsBroadcastableOpConverter
    : public OpConversionPattern<IsBroadcastableOp> {
  using OpConversionPattern<IsBroadcastableOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult IsBroadcastableOpConverter::matchAndRewrite(
    IsBroadcastableOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands, not
  // on shapes.
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  auto loc = op.getLoc();
  ImplicitLocOpBuilder lb(loc, rewriter);
  Value zero = arith::ConstantIndexOp::create(lb, 0);
  Value one = arith::ConstantIndexOp::create(lb, 1);
  Type indexTy = lb.getIndexType();

  // Save all the ranks for bounds checking. Because this is a tensor
  // representing the shape extents, the rank is the extent of the only
  // dimension in the tensor.
  SmallVector<Value> ranks, rankDiffs;
  llvm::append_range(ranks, llvm::map_range(adaptor.getShapes(), [&](Value v) {
                       return tensor::DimOp::create(lb, v, zero);
                     }));

  // Find the maximum rank.
  Value maxRank = ranks.front();
  for (Value v : llvm::drop_begin(ranks, 1)) {
    maxRank = arith::MaxUIOp::create(lb, v, maxRank);
  }

  // Calculate the difference of ranks and the maximum rank for later offsets.
  llvm::append_range(rankDiffs, llvm::map_range(ranks, [&](Value v) {
                       return arith::SubIOp::create(lb, indexTy, maxRank, v);
                     }));

  Type i1Ty = rewriter.getI1Type();
  Value trueVal = arith::ConstantOp::create(rewriter, loc, i1Ty,
                                            rewriter.getBoolAttr(true));

  auto reduceResult = ForOp::create(
      lb, loc, zero, maxRank, one, ValueRange{trueVal},
      [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
        // Find a non-1 dim, if it exists. Note that the first part of this
        // could reuse the Broadcast lowering entirely, but we redo the work
        // here to make optimizations easier between the two loops.
        Value broadcastedDim = getBroadcastedDim(
            ImplicitLocOpBuilder(loc, b), adaptor.getShapes(), rankDiffs, iv);

        Value broadcastable = iterArgs[0];
        for (auto tup : llvm::zip(adaptor.getShapes(), rankDiffs)) {
          Value shape, rankDiff;
          std::tie(shape, rankDiff) = tup;
          Value outOfBounds = arith::CmpIOp::create(
              b, loc, arith::CmpIPredicate::ult, iv, rankDiff);
          broadcastable =
              IfOp::create(
                  b, loc, outOfBounds,
                  [&](OpBuilder &b, Location loc) {
                    // Non-existent dimensions are always broadcastable.
                    scf::YieldOp::create(b, loc, broadcastable);
                  },
                  [&](OpBuilder &b, Location loc) {
                    // Every value needs to be either 1, or the same non-1
                    // value to be broadcastable in this dim.
                    Value operandDimension =
                        arith::SubIOp::create(b, loc, indexTy, iv, rankDiff);
                    Value dimensionExtent = tensor::ExtractOp::create(
                        b, loc, shape, ValueRange{operandDimension});

                    Value equalOne = arith::CmpIOp::create(
                        b, loc, arith::CmpIPredicate::eq, dimensionExtent, one);
                    Value equalBroadcasted =
                        arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq,
                                              dimensionExtent, broadcastedDim);
                    Value result = arith::AndIOp::create(
                        b, loc, broadcastable,
                        arith::OrIOp::create(b, loc, equalOne,
                                             equalBroadcasted));
                    scf::YieldOp::create(b, loc, result);
                  })
                  .getResult(0);
        }

        scf::YieldOp::create(b, loc, broadcastable);
      });

  rewriter.replaceOp(op, reduceResult.getResults().front());
  return success();
}
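
// In sketch form, the rewrite above lowers (schematic IR)
//
//   %ok = shape.is_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
//
// to an scf.for over [0, max(rank(%a), rank(%b))) that carries an i1
// iter_arg initialized to true: every iteration computes the broadcasted
// extent of the current dimension and and-accumulates the check "each
// in-bounds extent is 1 or equals the broadcasted extent".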

namespace {
class DimOpConverter : public OpConversionPattern<DimOp> {
  using OpConversionPattern<DimOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(DimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
DimOpConverter::matchAndRewrite(DimOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  // Lower dim(X, i) to get_extent(shape_of(X), i) and rely on further
  // lowerings. This can be further optimized if needed to avoid intermediate
  // steps.
  auto shapeOf = shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getValue());
  rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
                                                  op.getIndex());
  return success();
}
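
// Illustrative instance of this two-step rewrite (schematic IR):
//
//   %e = shape.dim %t, %i : tensor<?x?xf32>, index -> index
//
// becomes
//
//   %s = shape.shape_of %t : tensor<?x?xf32> -> tensor<2xindex>
//   %e = shape.get_extent %s, %i : tensor<2xindex>, index -> index
//
// and the remaining shape ops are picked up by the patterns below.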

namespace {
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
  using OpConversionPattern<GetExtentOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(GetExtentOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult GetExtentOpConverter::matchAndRewrite(
    GetExtentOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // For now, only error-free types are supported by this lowering.
  if (isa<SizeType>(op.getType()))
    return failure();

  // Derive shape extent directly from shape origin if possible. This
  // circumvents the necessity to materialize the shape in memory.
  if (auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>()) {
    if (isa<ShapedType>(shapeOfOp.getArg().getType())) {
      rewriter.replaceOpWithNewOp<tensor::DimOp>(op, shapeOfOp.getArg(),
                                                 adaptor.getDim());
      return success();
    }
  }

  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, rewriter.getIndexType(),
                                                 adaptor.getShape(),
                                                 ValueRange{adaptor.getDim()});
  return success();
}
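
// For example, when the shape operand comes from shape.shape_of, the
// pattern above skips the extent tensor entirely (schematic IR):
//
//   %s = shape.shape_of %t : tensor<?x?xf32> -> tensor<2xindex>
//   %e = shape.get_extent %s, %i : tensor<2xindex>, index -> index
//
// becomes
//
//   %e = tensor.dim %t, %i : tensor<?x?xf32>
//
// instead of a tensor.extract from a materialized extent tensor.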

namespace {
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
public:
  using OpConversionPattern<shape::RankOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // For now, this lowering supports only error-free types.
  if (isa<SizeType>(op.getType()))
    return failure();

  rewriter.replaceOpWithNewOp<tensor::DimOp>(op, adaptor.getShape(), 0);
  return success();
}

namespace {
/// Converts `shape.reduce` to `scf.for`.
struct ReduceOpConverter : public OpConversionPattern<shape::ReduceOp> {
public:
  using OpConversionPattern<shape::ReduceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final;
};
} // namespace

LogicalResult
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  // For now, this lowering is only defined on `tensor<?xindex>` operands.
  if (isa<ShapeType>(op.getShape().getType()))
    return failure();

  auto loc = op.getLoc();

  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
  Type indexTy = rewriter.getIndexType();
  Value rank =
      tensor::DimOp::create(rewriter, loc, indexTy, adaptor.getShape(), zero);

  auto loop = scf::ForOp::create(
      rewriter, loc, zero, rank, one, op.getInitVals(),
      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
        Value extent =
            tensor::ExtractOp::create(b, loc, adaptor.getShape(), iv);

        SmallVector<Value, 2> mappedValues{iv, extent};
        mappedValues.append(args.begin(), args.end());

        IRMapping mapping;
        Block *reduceBody = op.getBody();
        mapping.map(reduceBody->getArguments(), mappedValues);
        for (auto &nested : reduceBody->without_terminator())
          b.clone(nested, mapping);

        SmallVector<Value, 2> mappedResults;
        for (auto result : reduceBody->getTerminator()->getOperands())
          mappedResults.push_back(mapping.lookup(result));
        scf::YieldOp::create(b, loc, mappedResults);
      });

  rewriter.replaceOp(op, loop.getResults());
  return success();
}
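
// As an illustration, computing the number of elements of a shape with
// shape.reduce (schematic IR, adapted from the op's documentation):
//
//   %n = shape.reduce(%shape, %c1) : tensor<?xindex> -> index {
//     ^bb0(%i : index, %extent : index, %acc : index):
//       %next = arith.muli %acc, %extent : index
//       shape.yield %next : index
//   }
//
// becomes an scf.for from 0 to the rank with %acc as the iter_arg: every
// iteration extracts one extent via tensor.extract, re-emits the cloned
// body, and forwards the mapped shape.yield operands through scf.yield.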

namespace {
/// Converts `shape.shape_eq` to an `scf.for` loop. For now, the lowering is
/// only defined on `tensor<?xindex>` operands. The test for equality first
/// compares their size and, if equal, checks every extent for equality.
///
/// Example:
///
/// %result = shape.shape_eq %a, %b : tensor<?xindex>, tensor<?xindex>
///
/// becomes
///
/// %c0 = arith.constant 0 : index
/// %0 = tensor.dim %arg0, %c0 : tensor<?xindex>
/// %1 = tensor.dim %arg1, %c0 : tensor<?xindex>
/// %2 = arith.cmpi "eq", %0, %1 : index
/// %result = scf.if %2 -> (i1) {
///   %c1 = arith.constant 1 : index
///   %true = arith.constant true
///   %4 = scf.for %arg2 = %c0 to %0 step %c1 iter_args(%arg3 = %true) -> (i1) {
///     %5 = tensor.extract %arg0[%arg2] : tensor<?xindex>
///     %6 = tensor.extract %arg1[%arg2] : tensor<?xindex>
///     %7 = arith.cmpi "eq", %5, %6 : index
///     %8 = arith.andi %arg3, %7 : i1
///     scf.yield %8 : i1
///   }
///   scf.yield %4 : i1
/// } else {
///   %false = arith.constant false
///   scf.yield %false : i1
/// }
///
struct ShapeEqOpConverter : public OpConversionPattern<ShapeEqOp> {
  using OpConversionPattern<ShapeEqOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult
ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  if (!llvm::all_of(op.getShapes(),
                    [](Value v) { return !isa<ShapeType>(v.getType()); }))
    return failure();

  Type i1Ty = rewriter.getI1Type();
  if (op.getShapes().size() <= 1) {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, i1Ty,
                                                   rewriter.getBoolAttr(true));
    return success();
  }

  auto loc = op.getLoc();
  Type indexTy = rewriter.getIndexType();
  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
  Value firstShape = adaptor.getShapes().front();
  Value firstRank =
      tensor::DimOp::create(rewriter, loc, indexTy, firstShape, zero);
  Value result = nullptr;
  // Generate a linear sequence of compares, all with firstShape as lhs.
  for (Value shape : adaptor.getShapes().drop_front(1)) {
    Value rank = tensor::DimOp::create(rewriter, loc, indexTy, shape, zero);
    Value eqRank = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, firstRank, rank);
    auto same = IfOp::create(
        rewriter, loc, eqRank,
        [&](OpBuilder &b, Location loc) {
          Value one = arith::ConstantIndexOp::create(b, loc, 1);
          Value init =
              arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(true));
          auto loop = scf::ForOp::create(
              b, loc, zero, firstRank, one, ValueRange{init},
              [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
                Value conj = args[0];
                Value lhsExtent =
                    tensor::ExtractOp::create(b, loc, firstShape, iv);
                Value rhsExtent = tensor::ExtractOp::create(b, loc, shape, iv);
                Value eqExtent = arith::CmpIOp::create(
                    b, loc, arith::CmpIPredicate::eq, lhsExtent, rhsExtent);
                Value conjNext = arith::AndIOp::create(b, loc, conj, eqExtent);
                scf::YieldOp::create(b, loc, ValueRange({conjNext}));
              });
          scf::YieldOp::create(b, loc, loop.getResults());
        },
        [&](OpBuilder &b, Location loc) {
          Value result =
              arith::ConstantOp::create(b, loc, i1Ty, b.getBoolAttr(false));
          scf::YieldOp::create(b, loc, result);
        });
    result = !result ? same.getResult(0)
                     : arith::AndIOp::create(rewriter, loc, result,
                                             same.getResult(0));
  }
  rewriter.replaceOp(op, result);
  return success();
}

namespace {
class ShapeOfOpConversion : public OpConversionPattern<ShapeOfOp> {
public:
  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult ShapeOfOpConversion::matchAndRewrite(
    ShapeOfOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // For now, only error-free types are supported by this lowering.
  if (isa<ShapeType>(op.getType()))
    return failure();

  // For ranked tensor arguments, lower to `tensor.from_elements`.
  auto loc = op.getLoc();
  Value tensor = adaptor.getArg();
  Type tensorTy = tensor.getType();
  if (isa<RankedTensorType>(tensorTy)) {

    // Build values for individual extents.
    SmallVector<Value, 8> extentValues;
    RankedTensorType rankedTensorTy = cast<RankedTensorType>(tensorTy);
    int64_t rank = rankedTensorTy.getRank();
    for (int64_t i = 0; i < rank; i++) {
      if (rankedTensorTy.isDynamicDim(i)) {
        Value extent = tensor::DimOp::create(rewriter, loc, tensor, i);
        extentValues.push_back(extent);
      } else {
        Value extent = arith::ConstantIndexOp::create(
            rewriter, loc, rankedTensorTy.getDimSize(i));
        extentValues.push_back(extent);
      }
    }

    // Materialize extent tensor.
    Value staticExtentTensor = tensor::FromElementsOp::create(
        rewriter, loc, RankedTensorType::get({rank}, rewriter.getIndexType()),
        extentValues);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                staticExtentTensor);
    return success();
  }

  // Lower to `tensor.generate` otherwise.
  auto *ctx = rewriter.getContext();
  Value rank = tensor::RankOp::create(rewriter, loc, tensor);
  rewriter.replaceOpWithNewOp<tensor::GenerateOp>(
      op, getExtentTensorType(ctx), ValueRange{rank},
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value dim = args.front();
        Value extent = tensor::DimOp::create(b, loc, tensor, dim);
        tensor::YieldOp::create(b, loc, extent);
      });

  return success();
}
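
// For a ranked operand, the rewrite above produces, e.g. (schematic IR):
//
//   %s = shape.shape_of %t : tensor<1x5x?xf32> -> tensor<3xindex>
//
// becomes
//
//   %c1 = arith.constant 1 : index
//   %c5 = arith.constant 5 : index
//   %c2 = arith.constant 2 : index
//   %d = tensor.dim %t, %c2 : tensor<1x5x?xf32>
//   %s = tensor.from_elements %c1, %c5, %d : tensor<3xindex>
//
// while unranked operands take the tensor.generate path over tensor.rank.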

namespace {
class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
public:
  using OpConversionPattern<SplitAtOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitAtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

LogicalResult SplitAtOpConversion::matchAndRewrite(
    SplitAtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Error conditions are not implemented; only lower if all operands and
  // results are extent tensors.
  if (llvm::any_of(ValueRange{op.getOperand(), op.getHead(), op.getTail()},
                   [](Value v) { return isa<ShapeType>(v.getType()); }))
    return failure();

  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
  Value zero = arith::ConstantIndexOp::create(b, 0);
  Value rank = tensor::DimOp::create(b, adaptor.getOperand(), zero);

  // index < 0 ? index + rank : index
  Value originalIndex = adaptor.getIndex();
  Value add = arith::AddIOp::create(b, originalIndex, rank);
  Value indexIsNegative =
      arith::CmpIOp::create(b, arith::CmpIPredicate::slt, originalIndex, zero);
  Value index = arith::SelectOp::create(b, indexIsNegative, add, originalIndex);

  Value one = arith::ConstantIndexOp::create(b, 1);
  Value head =
      tensor::ExtractSliceOp::create(b, adaptor.getOperand(), zero, index, one);
  Value tailSize = arith::SubIOp::create(b, rank, index);
  Value tail = tensor::ExtractSliceOp::create(b, adaptor.getOperand(), index,
                                              tailSize, one);
  rewriter.replaceOp(op, {head, tail});
  return success();
}
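
// Illustration of the lowering above with a negative split point: for an
// extent tensor [a, b, c, d] and index -1, the index is normalized to
// -1 + 4 = 3, so (schematic IR, sizes actually computed dynamically)
//
//   %head = tensor.extract_slice %shape[0] [3] [1]   // [a, b, c]
//   %tail = tensor.extract_slice %shape[3] [1] [1]   // [d]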

namespace {
class ToExtentTensorOpConversion
    : public OpConversionPattern<ToExtentTensorOp> {
public:
  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!isa<RankedTensorType>(adaptor.getInput().getType()))
      return rewriter.notifyMatchFailure(op, "input needs to be a tensor");

    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                adaptor.getInput());
    return success();
  }
};
} // namespace

namespace {
/// Import the Shape Ops to Std Patterns.
#include "ShapeToStandard.cpp.inc"
} // namespace

namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
    : public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {

  void runOnOperation() override;
};
} // namespace

void ConvertShapeToStandardPass::runOnOperation() {
  // Setup target legality.
  MLIRContext &ctx = getContext();
  ConversionTarget target(ctx);
  target.addLegalDialect<arith::ArithDialect, SCFDialect,
                         tensor::TensorDialect>();
  target.addLegalOp<CstrRequireOp, func::FuncOp, ModuleOp>();

  // Setup conversion patterns.
  RewritePatternSet patterns(&ctx);
  populateShapeToStandardConversionPatterns(patterns);

  // Apply conversion.
  auto module = getOperation();
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

void mlir::populateShapeToStandardConversionPatterns(
    RewritePatternSet &patterns) {
  // clang-format off
  populateWithGenerated(patterns);
  patterns.add<
      AnyOpConversion,
      BinaryOpConversion<AddOp, arith::AddIOp>,
      BinaryOpConversion<MulOp, arith::MulIOp>,
      BroadcastOpConverter,
      ConstShapeOpConverter,
      ConstSizeOpConversion,
      DimOpConverter,
      IsBroadcastableOpConverter,
      GetExtentOpConverter,
      RankOpConverter,
      ReduceOpConverter,
      ShapeEqOpConverter,
      ShapeOfOpConversion,
      SplitAtOpConversion,
      ToExtentTensorOpConversion>(patterns.getContext());
  // clang-format on
}
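
// For reference, these patterns are normally exercised through the pass
// defined above, e.g. (pass flag as registered for this pass; shown for
// illustration):
//
//   mlir-opt --convert-shape-to-std input.mlir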