//===- LinalgOps.cpp - Implementation of the linalg operations -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::linalg;

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to
/// be ShapedType.
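///
/// As an illustration, for a `linalg.matmul` whose operands have `f32`
/// element types, the block built here has one `f32` argument per elemental
/// input/output type and a payload produced by the op's `regionBuilder`:
///
/// ```mlir
/// ^bb0(%a: f32, %b: f32, %c: f32):
///   %0 = arith.mulf %a, %b : f32
///   %1 = arith.addf %c, %0 : f32
///   linalg.yield %1 : f32
/// ```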
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));

  // TODO: atm all operands go through getElementTypeOrSelf;
  // reconsider when we have evidence we need to.
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(getElementTypeOrSelf(t));

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              llvm::Optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            [](Type type) { return type.isa<RankedTensorType>(); });

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operand_segment_sizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes) {
  SMLoc inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  result.addAttribute("operand_segment_sizes",
                      parser.getBuilder().getDenseI32ArrayAttr(
                          {static_cast<int32_t>(inputsOperands.size()),
                           static_cast<int32_t>(outputsOperands.size())}));
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
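
// These two helpers parse/print the `ins`/`outs` clauses of the structured op
// assembly format. For instance, a named op prints as:
//
//   linalg.matmul ins(%A, %B : memref<4x8xf32>, memref<8x16xf32>)
//                 outs(%C : memref<4x16xf32>)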

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operand_segment_sizes",
                       // See generated code in mlir-linalg-yaml-gen.cpp
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

/// This is a common class used for patterns of the form
/// ```
/// someop(memrefcast(%src)) -> someop(%src)
/// ```
/// It folds the source of the memref.cast into the root operation directly.
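///
/// For example (illustrative):
///
/// ```mlir
/// %0 = memref.cast %src : memref<4x8xf32> to memref<?x?xf32>
/// linalg.generic ... ins(%0 : memref<?x?xf32>) ...
/// ```
///
/// is rewritten so the generic op reads directly from `%src`, since a cast to
/// a more dynamic type can be absorbed by the consuming operation.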
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers build the unary, binary, and type conversion functions defined
// by the DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this
// class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(MLIRContext *context, Block &block)
      : context(context), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }
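
  // For instance (illustrative), `buildUnaryFn(UnaryFn::exp, %arg)` with an
  // f32 argument materializes:
  //   %0 = math.exp %arg : f32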

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder builder = getBuilder();
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }
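
  // For instance (illustrative), `buildBinaryFn(BinaryFn::add, %a, %b)`
  // materializes `complex.add` for complex operands, `arith.ori` for i1, and
  // `arith.addi` for other integers; for f32 it produces:
  //   %0 = arith.addf %a, %b : f32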

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder builder = getBuilder();
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    Type type = NoneType::get(builder.getContext());
    if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
      type = typedAttr.getType();
    return builder.create<arith::ConstantOp>(loc, type, valueAttr);
  }

  Value index(int64_t dim) {
    OpBuilder builder = getBuilder();
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context, width);
  }

  Type getFloat32Type() { return Float32Type::get(context); }
  Type getFloat64Type() { return Float64Type::get(context); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder builder = getBuilder();
    auto loc = operand.getLoc();

    if (operand.getType() == toType)
      return operand;
    if (auto toIntType = toType.dyn_cast<IntegerType>()) {
      // If operand is floating point, cast directly to the int type.
      if (operand.getType().isa<FloatType>()) {
        if (isUnsignedCast)
          return builder.create<arith::FPToUIOp>(loc, toType, operand);
        return builder.create<arith::FPToSIOp>(loc, toType, operand);
      }
      // Cast index operands directly to the int type.
      if (operand.getType().isIndex())
        return builder.create<arith::IndexCastOp>(loc, toType, operand);
      if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
        // Either extend or truncate.
        if (toIntType.getWidth() > fromIntType.getWidth()) {
          if (isUnsignedCast)
            return builder.create<arith::ExtUIOp>(loc, toType, operand);
          return builder.create<arith::ExtSIOp>(loc, toType, operand);
        }
        if (toIntType.getWidth() < fromIntType.getWidth())
          return builder.create<arith::TruncIOp>(loc, toType, operand);
      }
    } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
      // If operand is integer, cast directly to the float type.
      // Note that it is unclear how to cast from BF16<->FP16.
      if (operand.getType().isa<IntegerType>()) {
        if (isUnsignedCast)
          return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
        return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
      }
      if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
        if (toFloatType.getWidth() > fromFloatType.getWidth())
          return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
        if (toFloatType.getWidth() < fromFloatType.getWidth())
          return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
      }
    }

    emitWarning(operand.getLoc()) << "could not cast operand of type "
                                  << operand.getType() << " to " << toType;
    return operand;
  }
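
  // For instance (illustrative), casting an i8 operand to i32 emits
  //   %0 = arith.extsi %operand : i8 to i32
  // for a signed cast, and `arith.extui` for an unsigned one.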

  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }

  OpBuilder getBuilder() {
    OpBuilder builder(context);
    builder.setInsertionPointToEnd(&block);
    return builder;
  }

  MLIRContext *context;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
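///
/// For example (illustrative):
///
/// ```mlir
/// %fill = linalg.fill ins(%cst : f32)
///     outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
/// %0 = tensor.collapse_shape %fill [[0, 1]]
///     : tensor<4x4xf32> into tensor<16xf32>
/// ```
///
/// becomes a collapse of `%init` followed by a fill of the collapsed tensor.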
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
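///
/// For example (illustrative):
///
/// ```mlir
/// %fill = linalg.fill ins(%cst : f32)
///     outs(%init : tensor<4xf32>) -> tensor<4xf32>
/// %pad = tensor.pad %fill low[2] high[2] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
/// } : tensor<4xf32> to tensor<8xf32>
/// ```
///
/// can be rewritten as a single fill of a larger 8xf32 tensor.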
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    ReifyRankedShapedTypeOpInterface interface =
        cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
    if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto oldResultType = padOp.getResultType();
    SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
                                        ShapedType::kDynamicSize);
    auto newInitOp = rewriter.create<InitTensorOp>(
        padOp.getLoc(), reifiedShape.front(), staticShape,
        oldResultType.getElementType());
    auto newFillOp = rewriter.create<FillOp>(
        fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
                                                newFillOp.result());

    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
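///
/// For example (illustrative):
///
/// ```mlir
/// %pad = tensor.pad %src low[1] high[1] {
///   ^bb0(%i: index):
///     tensor.yield %cst : f32
/// } : tensor<4xf32> to tensor<6xf32>
/// %0 = tensor.insert_slice %pad into %fill[2] [6] [1]
///     : tensor<6xf32> into tensor<16xf32>
/// ```
///
/// becomes, when the padding value equals the fill value and the accessed
/// ranges are disjoint from earlier insertions:
///
/// ```mlir
/// %0 = tensor.insert_slice %src into %fill[3] [4] [1]
///     : tensor<4xf32> into tensor<16xf32>
/// ```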
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
      newSizes.push_back(
          rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
              .getResult());
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}

//===----------------------------------------------------------------------===//
// GenericOps
//===----------------------------------------------------------------------===//
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (!bodyBuild)
    return;

  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      blockArgTypes.push_back(getElementTypeOrSelf(v));
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  auto &region = *result.regions.front();
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, result.location, bodyBlock->getArguments());
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getStrArrayAttr(iteratorTypes),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<StringRef> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs())
    if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
      genericAttrs.push_back(attr);
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getInputs(), getOutputs());

  genericAttrNames.push_back("operand_segment_sizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

struct DeduplicateAndRemoveDeadOperandsAndResults
    : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Create a map from argument position in the original op to the argument
    // position in the new op. If the argument is dropped it won't have an
    // entry.
    SmallVector<OpOperand *> droppedOpOperands;

    // Information needed to build the new op.
    SmallVector<Value> newInputOperands, newOutputOperands;
    SmallVector<AffineMap> newIndexingMaps;

    // Gather information about duplicate input operands.
    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
        deduplicateInputOperands(genericOp, droppedOpOperands,
                                 newInputOperands, newIndexingMaps);

    // Gather information about the dropped outputs.
    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
        deduplicateOutputOperands(genericOp, droppedOpOperands,
                                  newOutputOperands, newIndexingMaps);

    // Check if there is any change to operands.
    if (newInputOperands.size() + newOutputOperands.size() ==
        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
      return failure();

    // Create the new op with the body being empty.
    Location loc = genericOp.getLoc();
    SmallVector<Type> newResultTypes;
    if (genericOp.hasTensorSemantics()) {
      newResultTypes = llvm::to_vector(llvm::map_range(
          newOutputOperands, [](Value v) { return v.getType(); }));
    }
    auto newOp = rewriter.create<GenericOp>(
        loc, newResultTypes, newInputOperands, newOutputOperands,
        rewriter.getAffineMapArrayAttr(newIndexingMaps),
        genericOp.getIteratorTypes(), genericOp.getDocAttr(),
        genericOp.getLibraryCallAttr(),
        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
          return;
        });
    // Copy over unknown attributes. They might be load-bearing for some flow.
    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
    for (NamedAttribute kv : genericOp->getAttrs())
      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
        newOp->setAttr(kv.getName(), kv.getValue());

    // Fix up the payload of the canonicalized operation.
    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
                      origOutsToNewOutsPos, rewriter);

    // Replace all live uses of the op.
    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
    for (const auto &result : llvm::enumerate(genericOp.getResults())) {
      auto it = origOutsToNewOutsPos.find(result.index());
      if (it == origOutsToNewOutsPos.end())
        continue;
      replacementsVals[result.index()] = newOp.getResult(it->second);
    }
    rewriter.replaceOp(genericOp, replacementsVals);
    return success();
  }

private:
  // Deduplicate input operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved input operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateInputOperands(GenericOp genericOp,
                           SmallVector<OpOperand *> &droppedOpOperands,
                           SmallVector<Value> &newInputOperands,
                           SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
    for (const auto &inputOpOperand :
         llvm::enumerate(genericOp.getInputOperands())) {
      // Check if the operand is dead and if dropping its indexing map would
      // make the loop-to-shape computation invalid.
      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
        // Add the current operand to the list of potentially droppable
        // operands. If it cannot be dropped, this needs to be popped back.
        droppedOpOperands.push_back(inputOpOperand.value());
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // Check if this operand is a duplicate.
      AffineMap indexingMap =
          genericOp.getTiedIndexingMap(inputOpOperand.value());
      auto it = dedupedInputs.find(
          std::make_pair(inputOpOperand.value()->get(), indexingMap));
      if (it != dedupedInputs.end()) {
        origToNewPos[inputOpOperand.index()] = it->second;
        droppedOpOperands.push_back(inputOpOperand.value());
        continue;
      }

      // This is a preserved argument.
      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
          newInputOperands.size();
      newInputOperands.push_back(inputOpOperand.value()->get());
      newIndexingMaps.push_back(indexingMap);
    }
    return origToNewPos;
  }

  // Deduplicate output operands, and return the
  // - Mapping from operand position in the original op, to operand position in
  //   the canonicalized op.
  // - The preserved output operands list (by reference).
  llvm::SmallDenseMap<unsigned, unsigned>
  deduplicateOutputOperands(GenericOp genericOp,
                            SmallVector<OpOperand *> &droppedOpOperands,
                            SmallVector<Value> &newOutputOperands,
                            SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
        dedupedOutpts;
    // If the op doesn't have tensor semantics, keep all the outputs as
    // preserved.
    if (!genericOp.hasTensorSemantics()) {
      for (const auto &outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    } else {
      // An output argument can be dropped if the result has
      // - no users, and
      // - it is not used in the payload, and
      // - the corresponding indexing maps are not needed for loop bound
      //   computation.
      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
      for (const auto &outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        Value result = genericOp.getResult(outputOpOperand.index());
        AffineMap indexingMap =
            genericOp.getTiedIndexingMap(outputOpOperand.value());
        auto key =
            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
                            yieldOp->getOperand(outputOpOperand.index()));

        // Do not drop an output operand if its value is used in the payload.
        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
          if (result.use_empty()) {
            // Check if the opoperand can be dropped without affecting loop
            // bound computation. Add the operand to the list of dropped op
            // operands for checking. If it cannot be dropped, the value needs
            // to be popped back.
            droppedOpOperands.push_back(outputOpOperand.value());
            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
              continue;
            }
            droppedOpOperands.pop_back();
          }

          // The out operand can also be dropped if it is computed redundantly
          // by another result; the conditions for that are:
          // - The same operand is used as the out operand.
          // - The same indexing map is used.
          // - The same yield value is used.
          auto it = dedupedOutpts.find(key);
          if (it != dedupedOutpts.end()) {
            origToNewPos[outputOpOperand.index()] = it->second;
            droppedOpOperands.push_back(outputOpOperand.value());
            continue;
          }
        }

        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        dedupedOutpts[key] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getTiedIndexingMap(outputOpOperand.value()));
      }
    }

    return origToNewPos;
  }

  // Populate the body of the canonicalized operation.
  void populateOpPayload(
      GenericOp genericOp, GenericOp newOp,
      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
      PatternRewriter &rewriter) const {
    // Merge the body of the original op with the new op.
    Block *newOpBlock = &newOp.getRegion().front();
    assert(newOpBlock->empty() && "expected new op to have an empty payload");
    Block *origOpBlock = &genericOp.getRegion().front();
    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);

    // Replace all arguments in the original op, with arguments from the
    // canonicalized op.
    auto updateReplacements =
        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
          for (const auto &origOperand : llvm::enumerate(origOperands)) {
            auto it = map.find(origOperand.index());
            if (it == map.end())
              continue;
            OpOperand *newOperand = newOperands[it->second];
            replacements[origOperand.value()->getOperandNumber()] =
                newOpBlock->getArgument(newOperand->getOperandNumber());
          }
        };

    OpOperandVector origInputOperands = genericOp.getInputOperands();
    OpOperandVector newInputOperands = newOp.getInputOperands();
    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);

    OpOperandVector origOutputOperands = genericOp.getOutputOperands();
    OpOperandVector newOutputOperands = newOp.getOutputOperands();
    updateReplacements(origOutputOperands, newOutputOperands,
                       origOutsToNewOutsPos);

    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);

    // Drop the unused yield args.
    if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
      OpBuilder::InsertionGuard g(rewriter);
      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
      rewriter.setInsertionPoint(origYieldOp);

      SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
      for (const auto &yieldOpOperands :
           llvm::enumerate(origYieldOp.getValues())) {
        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
        if (it == origOutsToNewOutsPos.end())
          continue;
        newYieldVals[it->second] = yieldOpOperands.value();
      }
      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
    }
  }
};

/// Remove generic operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
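///
/// For example (illustrative), the following op simply forwards its input and
/// is replaced by `%arg0`:
///
/// ```mlir
/// %0 = linalg.generic {
///     indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
///     iterator_types = ["parallel"]}
///     ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
/// } -> tensor<8xf32>
/// ```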
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = genericOp.getRegion().front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (genericOp.hasBufferSemantics()) {
      if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
          genericOp.getInputOperand(0)->get() ==
              genericOp.getOutputOperand(0)->get()) {
        rewriter.eraseOp(genericOp);
        return success();
      }
      return failure();
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = genericOp->getOperand(argumentNumber);
      Type resultType = genericOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              genericOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              genericOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != genericOp->getNumResults())
      return failure();
    rewriter.replaceOp(genericOp, returnedArgs);
    return success();
  }
};
} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
          context);
}

LogicalResult GenericOp::fold(ArrayRef<Attribute>,
                              SmallVectorImpl<OpFoldResult> &) {
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// InitTensorOp
//===----------------------------------------------------------------------===//

void InitTensorOp::build(OpBuilder &b, OperationState &result,
                         ArrayRef<OpFoldResult> sizes, Type elementType,
                         ArrayRef<NamedAttribute> attrs) {
  SmallVector<Value, 4> dynamicSizes;
  SmallVector<int64_t, 4> staticSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
                             ShapedType::kDynamicSize);
  auto resultType = RankedTensorType::get(staticSizes, elementType);
  build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
  result.addAttributes(attrs);
}

LogicalResult InitTensorOp::verify() {
  RankedTensorType resultType = getType();
  SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
      getStaticSizes().cast<ArrayAttr>(),
      [](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));

  if (failed(verifyListOfOperandsOrIntegers(
          *this, "sizes", resultType.getRank(), getStaticSizes(), getSizes(),
          ShapedType::isDynamic)))
    return failure();

  if (getStaticSizes().size() != static_cast<unsigned>(resultType.getRank()))
    return emitError("expected ") << resultType.getRank() << " sizes values";

  Type expectedType = InitTensorOp::inferResultType(
      staticSizes, resultType.getElementType(), resultType.getEncoding());
  if (resultType != expectedType) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}

Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
                                   Type elementType, Attribute encoding) {
  return RankedTensorType::get(staticSizes, elementType, encoding);
}

SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
  SmallVector<OpFoldResult> mixedSizes;
  mixedSizes.reserve(getType().getRank());
  unsigned dynamicValIndex = 0;
  for (Attribute attr : getStaticSizes()) {
    auto intAttr = attr.cast<IntegerAttr>();
    if (!ShapedType::isDynamic(intAttr.getInt())) {
      mixedSizes.push_back(intAttr);
      continue;
    }
    mixedSizes.push_back(getSizes()[dynamicValIndex++]);
  }
  return mixedSizes;
}

namespace {
/// Change the type of the result of a `linalg.init_tensor` by making the
/// result type statically sized along dimensions that in the original
/// operation were defined as dynamic, but whose size was defined using a
/// `constant` op. For example
///
///  %c5 = arith.constant 5: index
///  %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
///
/// to
///
///  %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
  using OpRewritePattern<InitTensorOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InitTensorOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 4> dynamicSizes;
    SmallVector<int64_t, 4> staticSizes;
    for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
      // If the size is already static, nothing to do.
      if (!op.isDynamicSize(i)) {
        staticSizes.push_back(op.getStaticSize(i));
        continue;
      }

      // If the size is dynamic but defined using a `constant` op, get the
      // constant value to find the static size to use.
      unsigned operandNum = op.getIndexOfDynamicSize(i);
      Value sizeOperand = op.getOperand(operandNum);
      if (auto constantIndexOp =
              sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constantIndexOp.value());
        continue;
      }

      // Fallback case. Keep the size dynamic.
      dynamicSizes.push_back(sizeOperand);
      staticSizes.push_back(ShapedType::kDynamicSize);
    }
    RankedTensorType newType =
        RankedTensorType::get(staticSizes, op.getType().getElementType());
    if (newType == op.getType())
      return failure();
    auto newOp =
        rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
                                      rewriter.getI64ArrayAttr(staticSizes));
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
} // namespace

namespace {
/// Since the `init_tensor` operation creates a tensor needed only for its
/// shape, a slice of it is likewise needed only for its shape. The result can
/// therefore be replaced by a new init_tensor operation of the same size as
/// the extract slice op.
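///
/// For example (illustrative):
///
/// ```mlir
/// %0 = linalg.init_tensor [16, 16] : tensor<16x16xf32>
/// %1 = tensor.extract_slice %0[0, 0] [4, 4] [1, 1]
///     : tensor<16x16xf32> to tensor<4x4xf32>
/// ```
///
/// folds to `%1 = linalg.init_tensor [4, 4] : tensor<4x4xf32>`.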
struct FoldInitTensorWithExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    if (!sliceOp.getSource().getDefiningOp<linalg::InitTensorOp>())
      return failure();
    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
    // as well as its result type.
    rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
        sliceOp, sliceOp.getSizes(),
        sliceOp.getResult().getType().cast<RankedTensorType>().getShape(),
        sliceOp.getSourceType().getElementType());
    return success();
  }
};

template <typename TensorReshapeOp>
struct FoldInitTensorWithTensorReshapeOp
    : public OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    if (!reshapeOp.getSrc().template getDefiningOp<InitTensorOp>())
      return failure();
    Location loc = reshapeOp.getLoc();
    ReifiedRankedShapedTypeDims resultShapes;
    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
                                                          resultShapes)) ||
        !llvm::hasSingleElement(resultShapes))
      return failure();
    Value initTensor = rewriter.create<InitTensorOp>(
        loc, getAsOpFoldResult(resultShapes[0]),
        reshapeOp.getResultType().getElementType());
    if (initTensor.getType() != reshapeOp.getResultType()) {
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          reshapeOp, reshapeOp.getResultType(), initTensor);
    } else {
      rewriter.replaceOp(reshapeOp, initTensor);
    }
    return success();
  }
};

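/// Fold `tensor.dim` of a dynamic dimension of a `linalg.init_tensor` to the
/// SSA value that defines that size. For example (illustrative):
///
/// ```mlir
/// %0 = linalg.init_tensor [%sz, 4] : tensor<?x4xf32>
/// %d0 = tensor.dim %0, %c0 : tensor<?x4xf32>
/// ```
///
/// folds `%d0` to `%sz`.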
struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
  using OpRewritePattern<tensor::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto initTensorOp = dimOp.getSource().getDefiningOp<linalg::InitTensorOp>();
    if (!initTensorOp || !maybeConstantIndex)
      return failure();
    if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};

/// Canonicalize
///
/// ```mlir
///   %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
///   %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape, so it
/// is safe to assume that `%d0` is in fact 4. If that was not the case, the
/// input program is wrong to begin with, so it's undefined behavior anyway
/// (i.e., this optimization can still trigger without violating program
/// semantics).
struct FoldInitTensorWithTensorCastOp
    : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<InitTensorOp>();
    if (!producer)
      return failure();

    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of init_tensor op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The init tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = currDim.dyn_cast<Attribute>()) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != attr.cast<IntegerAttr>().getInt()) {
          // Something is off, the cast result shape cannot be more dynamic
          // than the init tensor result shape (enforced by
          // `canFoldIntoProducer`). Abort for now.
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of init "
                        "tensor result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the init tensor result
      // shape is dynamic.
      if (!ShapedType::isDynamic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: The tensor cast shape is dynamic and the init tensor result
      // shape is dynamic. Use the dynamic value from the init tensor op.
      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
                                              resultType.getElementType());
    return success();
  }
};

} // namespace

void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
              FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
}

LogicalResult InitTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}
1462 
1463 //===----------------------------------------------------------------------===//
1464 // YieldOp
1465 //===----------------------------------------------------------------------===//
1466 
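// For reference, the custom form printed below looks like (illustrative):
//   linalg.yield %a, %b : f32, i32
// i.e. the operands, an optional attribute dictionary, and the operand types.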
1467 void YieldOp::print(OpAsmPrinter &p) {
1468  if (getNumOperands() > 0)
1469  p << ' ' << getOperands();
1470  p.printOptionalAttrDict((*this)->getAttrs());
1471  if (getNumOperands() > 0)
1472  p << " : " << getOperandTypes();
1473 }
1474 
1475 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1476  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
1477  SmallVector<Type, 2> types;
1478  SMLoc loc = parser.getCurrentLocation();
1479  return failure(parser.parseOperandList(opInfo) ||
1480  parser.parseOptionalAttrDict(result.attributes) ||
1481  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1482  parser.resolveOperands(opInfo, types, loc, result.operands));
1483 }
1484 
1485 // Check that the number and types of the yielded operands match the element
1486 // types of the enclosing LinalgOp's output operands.
1487 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1488  if (op.getNumOperands() != linalgOp.getNumOutputs())
1489  return op.emitOpError("expected number of yield values (")
1490  << linalgOp.getNumOutputs()
1491  << ") to match the number of operands of the enclosing "
1492  << "LinalgOp (" << op.getNumOperands() << ")";
1493 
1494  for (OpOperand &opOperand : op->getOpOperands()) {
1495  OpOperand *outputOperand =
1496  linalgOp.getOutputOperand(opOperand.getOperandNumber());
1497  Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1498  if (opOperand.get().getType() != elementType)
1499  return op.emitOpError("type of yield operand ")
1500  << (opOperand.getOperandNumber() + 1) << " ("
1501  << opOperand.get().getType() << ") doesn't match "
1502  << "the element type of the enclosing linalg.generic op ("
1503  << elementType << ")";
1504  }
1505  return success();
1506 }
1507 
1508 LogicalResult YieldOp::verify() {
1509  auto *parentOp = (*this)->getParentOp();
1510  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1511  return emitOpError("expected single non-empty parent region");
1512 
1513  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1514  return verifyYield(*this, linalgOp);
1515 
1516  return emitOpError("expected parent op with LinalgOp interface");
1517 }
1518 
1519 //===----------------------------------------------------------------------===//
1520 // IndexOp
1521 //===----------------------------------------------------------------------===//
1522 
1523 LogicalResult IndexOp::verify() {
1524  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1525  if (!linalgOp)
1526  return emitOpError("expected parent op with LinalgOp interface");
1527  if (linalgOp.getNumLoops() <= getDim())
1528  return emitOpError("expected dim (")
1529  << getDim() << ") to be lower than the number of loops ("
1530  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1531  return success();
1532 }
1533 
1534 /////// Operations corresponding to library calls defined with Tablegen ////////
1535 
1536 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1537 
1538 #define GET_OP_CLASSES
1539 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1540 
1541 #define GET_OP_CLASSES
1542 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1543 
1544 /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
1545 /// Assumes `op` is a LinalgOp.
1546 void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
1547  SmallVectorImpl<unsigned> &res) {
1548  if (!cast<LinalgOp>(op).iterator_types())
1549  return;
1550 
1551  unsigned dim = 0;
1552  for (auto tn :
1553  cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
1554  if (tn == iteratorTypeName)
1555  res.push_back(dim);
1556  ++dim;
1557  }
1558 }
1559 
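// Usage sketch (editorial, names hypothetical): extractOrIdentityMap(None, 3,
// ctx) returns the identity map (d0, d1, d2) -> (d0, d1, d2), while a rank of
// 0 returns the empty map with no dimensions, symbols, or results.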
1560 AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
1561  unsigned rank,
1562  MLIRContext *context) {
1563  if (maybeMap)
1564  return *maybeMap;
1565  if (rank == 0)
1566  return AffineMap::get(context);
1567  return AffineMap::getMultiDimIdentityMap(rank, context);
1568 }
1569 
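// Worked example (editorial): with startIdx == 3, makeAffineDimExprs(2,
// startIdx, ctx) returns [d3, d4] and leaves startIdx == 5 for the next call.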
1570 SmallVector<AffineExpr, 4>
1571 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1572  MLIRContext *context) {
1573  SmallVector<AffineExpr, 4> res;
1574  res.reserve(num);
1575  for (unsigned i = 0; i < num; ++i)
1576  res.push_back(getAffineDimExpr(startIdx++, context));
1577  return res;
1578 }
1579 
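// E.g. (editorial): concat([d0, d1], [s0]) returns [d0, d1, s0].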
1580 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
1581  ArrayRef<AffineExpr> b) {
1582  auto rangeA = llvm::make_range(a.begin(), a.end());
1583  auto rangeB = llvm::make_range(b.begin(), b.end());
1584  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1585  return llvm::to_vector<4>(concatRanges);
1586 }
1587 
1588 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1589  if (auto memref = t.dyn_cast<MemRefType>()) {
1590  ss << "view";
1591  for (auto size : memref.getShape())
1592  if (size < 0)
1593  ss << "sx";
1594  else
1595  ss << size << "x";
1596  appendMangledType(ss, memref.getElementType());
1597  } else if (auto vec = t.dyn_cast<VectorType>()) {
1598  ss << "vector";
1599  llvm::interleave(
1600  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1601  appendMangledType(ss, vec.getElementType());
1602  } else if (t.isSignlessIntOrIndexOrFloat()) {
1603  ss << t;
1604  } else {
1605  llvm_unreachable("Invalid type for linalg library name mangling");
1606  }
1607 }
1608 
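// As an example of the mangling above (editorial, assuming a linalg.matmul
// whose three operands are all memref<?x?xf32>): dots in the op name become
// underscores and each operand type is appended by appendMangledType, giving
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32".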
1609 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
1610  assert(isa<LinalgOp>(op));
1611  std::string name(op->getName().getStringRef().str());
1612  name.reserve(128);
1613  std::replace(name.begin(), name.end(), '.', '_');
1614  llvm::raw_string_ostream ss(name);
1615  ss << "_";
1616  auto types = op->getOperandTypes();
1617  llvm::interleave(
1618  types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1619  [&]() { ss << "_"; });
1620  return ss.str();
1621 }
1622 
1623 //===----------------------------------------------------------------------===//
1624 // Canonicalizers and Folders.
1625 //===----------------------------------------------------------------------===//
1626 
1627 namespace {
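/// Erases a LinalgOp when one of its memref operands has a zero-sized
/// dimension, e.g. (illustrative) an op writing to memref<0x16xf32>: such an
/// op performs zero iterations and can be removed.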
1628 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1629  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1630 
1631  LogicalResult matchAndRewrite(LinalgOp op,
1632  PatternRewriter &rewriter) const override {
1633  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
1634  // Linalg "inputs" may be either tensor or memref type.
1635  // tensor<0xelt_type> is a convention that may not always mean
1636  // "0 iterations". Only erase in cases we see memref<...x0x...>.
1637  auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
1638  if (!mt)
1639  continue;
1640  if (llvm::is_contained(op.getShape(opOperand), 0)) {
1641  rewriter.eraseOp(op);
1642  return success();
1643  }
1644  }
1645  return failure();
1646  }
1647 };
1648 
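/// Folds `tensor.cast` producers into a LinalgOp consumer when the cast goes
/// to a more dynamic type (see `canFoldIntoConsumerOp`). A minimal sketch,
/// with hypothetical SSA names and the generic body elided:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
/// %2 = linalg.generic ... ins(%1 : tensor<?x?xf32>) ...
/// ```
///
/// becomes a linalg.generic that consumes %0 : tensor<4x8xf32> directly, with
/// `tensor.cast` ops inserted on any results whose types change.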
1649 struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
1650  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1651 
1652  LogicalResult matchAndRewrite(LinalgOp op,
1653  PatternRewriter &rewriter) const override {
1654  // If no operand comes from a tensor::CastOp that can be folded, fail.
1655  bool hasTensorCastOperand =
1656  llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
1657  if (opOperand->get().isa<BlockArgument>())
1658  return false;
1659  auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1660  return castOp && canFoldIntoConsumerOp(castOp);
1661  });
1662  if (!hasTensorCastOperand)
1663  return failure();
1664 
1665  SmallVector<Type, 4> newResultTypes;
1666  newResultTypes.reserve(op->getNumResults());
1667  SmallVector<Value, 4> newOperands;
1668  newOperands.reserve(op->getNumOperands());
1669  // Inputs may fold.
1670  for (OpOperand *opOperand : op.getInputOperands()) {
1671  auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1672  newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
1673  ? tensorCastOp.getSource()
1674  : opOperand->get());
1675  }
1676  // Init tensors may fold, in which case the resultType must also change.
1677  for (OpOperand *opOperand : op.getOutputOperands()) {
1678  auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
1679  bool fold = canFoldIntoConsumerOp(tensorCastOp);
1680  newOperands.push_back(fold ? tensorCastOp.getOperand()
1681  : opOperand->get());
1682  newResultTypes.push_back(newOperands.back().getType());
1683  }
1684  // Clone op.
1685  Operation *newOp =
1686  op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
1687  SmallVector<Value, 4> replacements;
1688  replacements.reserve(newOp->getNumResults());
1689  for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
1690  Value oldResult = std::get<0>(result);
1691  Value newResult = std::get<1>(result);
1692  if (newResult.getType() != oldResult.getType()) {
1693  replacements.push_back(rewriter.create<tensor::CastOp>(
1694  op->getLoc(), oldResult.getType(), newResult));
1695  } else {
1696  replacements.push_back(newResult);
1697  }
1698  }
1699  rewriter.replaceOp(op, replacements);
1700 
1701  return success();
1702  }
1703 };
1704 
1705 /// Fold a LinalgOp with a `tensor.cast` consumer if the `tensor.cast` has a
1706 /// result that is more static than the linalg op.
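/// A minimal sketch (hypothetical SSA names, generic body elided):
///
/// ```mlir
/// %1 = linalg.generic ... outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```
///
/// is rewritten so the `outs` operand is first cast to tensor<4x8xf32>, the
/// linalg op produces the static type directly, and a cast back to the
/// original dynamic type replaces the remaining uses of the old result.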
1707 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1708  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1709 
1710  LogicalResult matchAndRewrite(tensor::CastOp castOp,
1711  PatternRewriter &rewriter) const override {
1712  if (!tensor::canFoldIntoProducerOp(castOp))
1713  return failure();
1714 
1715  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1716  if (!linalgOp)
1717  return failure();
1718 
1719  // The cast can be in a conditionally reachable region, in which case folding
1720  // will generate invalid code. Only conservatively fold ops in the same block
1721  // for now.
1722  if (castOp->getBlock() != linalgOp->getBlock())
1723  return failure();
1724 
1725  OpBuilder::InsertionGuard guard(rewriter);
1726  rewriter.setInsertionPoint(linalgOp);
1727 
1728  Location loc = linalgOp.getLoc();
1729  OpResult resultValue = castOp.getSource().cast<OpResult>();
1730  unsigned resultNumber = resultValue.getResultNumber();
1731  auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1732  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1733  // going from a more dynamic shape to a less dynamic shape. If the producer
1734  // for this cast, i.e. producer of the out operand, is also an operation
1735  // that folds with tensor.cast consumer (like this pattern), the cast will
1736  // continue to propagate as far up the stack as it can go.
1737  OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
1738  Value newOperand =
1739  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1740  SmallVector<Value> newOperands = linalgOp.getInputOperands();
1741  SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
1742  outputOperands[resultNumber] = newOperand;
1743  newOperands.append(outputOperands.begin(), outputOperands.end());
1744 
1745  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1746  linalgOp->result_type_end());
1747  resultTypes[resultNumber] = resultType;
1748  Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);
1749 
1750  // Create a tensor.cast operation back to the original type.
1751  Value castBack = rewriter.create<tensor::CastOp>(
1752  loc, resultValue.getType(), newOp->getResult(resultNumber));
1753 
1754  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1755  results[resultNumber] = castBack;
1756  rewriter.replaceOp(linalgOp, results);
1757  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1758  return success();
1759  }
1760 };
1761 
1762 /// For each operand in `operands`, this function maps the affine dim
1763 /// expressions of its indexing map to the static sizes of its dimensions.
1764 static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
1765  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
1766  for (OpOperand *opOperand : operands) {
1767  if (linalgOp.isScalar(opOperand))
1768  continue;
1769  Value src = opOperand->get();
1770  auto sourceType = src.getType().cast<RankedTensorType>();
1771  auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1772 
1773  // Get the `sourceShape` of the `sourceType`. If the operand is the result
1774  // of a `tensor.cast` operation and the source of the cast has a static
1775  // shape, then use that shape as the `sourceShape`.
1776  auto *parentOp = src.getDefiningOp();
1777  ArrayRef<int64_t> sourceShape = sourceType.getShape();
1778  if (parentOp) {
1779  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
1780  Value castSource = castOp.getSource();
1781  auto castSourceType = castSource.getType().cast<RankedTensorType>();
1782  if (castSourceType.hasStaticShape())
1783  sourceShape = castSourceType.getShape();
1784  }
1785  }
1786 
1787  // For each dimension of the source shape that is static, map the affine dim
1788  // expression to the known static size.
1789  for (unsigned i = 0; i < sourceShape.size(); i++) {
1790  if (sourceType.isDynamicDim(i))
1791  continue;
1792  if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
1793  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
1794  }
1795  }
1796 }
1797 
1798 /// Creates a new operand for 'opOperand' of `linalgOp` with the static sizes
1799 /// mapped in `affineExprToSize`. New operands are appended to `newOperands`
1800 /// and their result types are stored in `resultTypes`. If `opOperand` requires
1801 /// no change, `changeNeeded` is left untouched and the same operand is added
1802 /// to the `newOperands` list.
1803 static void createNewOperandWithStaticSizes(
1804  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
1805  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
1806  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
1807  bool &changeNeeded) {
1808  Value src = opOperand->get();
1809  newOperands.push_back(src);
1810  if (linalgOp.isScalar(opOperand))
1811  return;
1812  auto sourceType = src.getType().cast<RankedTensorType>();
1813  Type resultType = sourceType;
1814  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
1815  resultTypes.push_back(resultType);
1816  return;
1817  }
1818  ArrayRef<int64_t> sourceShape = sourceType.getShape();
1819  AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
1820  SmallVector<int64_t> newShape;
1821  // If the operand is updated with a new shape, `newOperandNeeded` will be
1822  // set to true.
1823  bool newOperandNeeded = false;
1824  for (unsigned i = 0; i < sourceShape.size(); i++) {
1825  int64_t dimShape = sourceShape[i];
1826  AffineExpr dimExpr = sourceMap.getResult(i);
1827  if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
1828  !sourceType.isDynamicDim(i)) {
1829  newShape.push_back(dimShape);
1830  continue;
1831  }
1832  // Dimension has a dynamic shape and corresponding affine dim
1833  // expression is present in the map. So assign the size for the
1834  // given affine dim expression to the dimension.
1835  newShape.push_back(affineExprToSize[dimExpr]);
1836  newOperandNeeded = true;
1837  }
1838  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
1839  if (newOperandNeeded) {
1840  changeNeeded = true;
1841  // Get the new operand value given its size and element type by
1842  // casting it.
1843  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
1844  unsigned index = opOperand->getOperandNumber();
1845  newOperands[index] = newOperand;
1846  }
1847  if (linalgOp.isOutputTensor(opOperand))
1848  resultTypes.push_back(resultType);
1849 }
1850 
1851 /// Static shapes for the operands can be inferred if any one of the operands
1852 /// has a static shape. This can be done by referring to the affine dim
1853 /// expressions of the operands.
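/// A minimal sketch (editorial, hypothetical SSA names): for a generic op with
/// identity indexing maps taking %a : tensor<4x8xf32> and %b : tensor<?x?xf32>,
/// dims d0 and d1 are known to be 4 and 8, so %b is cast to tensor<4x8xf32>
/// and the op is cloned with the more static operand and result types.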
1854 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
1855  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
1856 
1857  LogicalResult matchAndRewrite(LinalgOp linalgOp,
1858  PatternRewriter &rewriter) const override {
1859  if (!linalgOp.hasTensorSemantics())
1860  return failure();
1861 
1862  // Maps must be projected permutations.
1863  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
1864  return !map.isProjectedPermutation();
1865  }))
1866  return failure();
1867 
1868  // Maps affine dim expressions to the static size of that dimension.
1869  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
1870  Location loc = linalgOp.getLoc();
1871 
1872  // For each affine dim expression, check if the size is known. If it is,
1873  // add it to the map.
1874  populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
1875  affineExprToSize);
1876 
1877  SmallVector<Value> newOperands;
1878  SmallVector<Type> resultTypes;
1879 
1880  // `changeNeeded` is `false` if the operands of `linalgOp` require no
1881  // change in their types.
1882  bool changeNeeded = false;
1883  newOperands.reserve(linalgOp.getNumInputsAndOutputs());
1884  resultTypes.reserve(linalgOp.getNumOutputs());
1885 
1886  // Iterate over all the operands and update the static sizes.
1887  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
1888  createNewOperandWithStaticSizes(loc, rewriter, opOperand,
1889  affineExprToSize, linalgOp, newOperands,
1890  resultTypes, changeNeeded);
1891  }
1892 
1893  // If the generic op has all the required static information, no
1894  // canonicalization needed.
1895  if (!changeNeeded)
1896  return failure();
1897 
1898  // Clone op.
1899  Operation *newOp =
1900  linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
1901  SmallVector<Value> replacements;
1902  replacements.reserve(newOp->getNumResults());
1903  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
1904  Value newResult = std::get<1>(it);
1905  Value oldResult = std::get<0>(it);
1906  Type newType = newResult.getType();
1907  Type oldType = oldResult.getType();
1908  replacements.push_back(
1909  (newType != oldType)
1910  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
1911  : newResult);
1912  }
1913  rewriter.replaceOp(linalgOp, replacements);
1914  return success();
1915  }
1916 };
1917 
1918 } // namespace
1919 
1920 // All named ops canonicalizers and folders are auto-generated in the
1921 // .cpp.inc.
1922 
1923 //===----------------------------------------------------------------------===//
1924 // LinalgDialect
1925 //===----------------------------------------------------------------------===//
1926 
1927 void LinalgDialect::getCanonicalizationPatterns(
1928  RewritePatternSet &results) const {
1929  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
1930  FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
1931  getContext());
1932 }
1933 
1934 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
1935  Attribute value, Type type,
1936  Location loc) {
1937  return builder.create<arith::ConstantOp>(loc, type, value);
1938 }