// NOTE(review): the three lines below were Doxygen page chrome captured by the
// scrape ("MLIR 17.0.0git" / "LinalgOps.cpp" / "Go to the documentation of
// this file.") — kept only as this comment so the file stays compilable.
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

// NOTE(review): the scrape dropped several include lines (original lines
// 13-35); the list below keeps every include that was visible and restores the
// dropped headers that the code in this file demonstrably uses (Linalg ops,
// arith/complex/math/tensor dialects, affine apply folding, attribute
// parsing). Verify against the upstream file.
#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
46 using namespace mlir;
47 using namespace mlir::linalg;
48 
49 //===----------------------------------------------------------------------===//
50 // Support for named Linalg ops defined in ods-gen.
51 //===----------------------------------------------------------------------===//
52 
55 
56 /// Fills the region of a structured operation using the provided
57 /// `regionBuilder`. The method is used by both named structured ops created by
58 /// ods-gen and by manually defined C++ ops. It is called by both builders and
59 /// parsers and creates a block with arguments corresponding to the elemental
60 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
61 /// ShapedType.
62 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
63  TypeRange inputTypes, TypeRange outputTypes,
65  RegionBuilderFn regionBuilder) {
66  assert(llvm::all_of(outputTypes,
67  [](Type t) { return llvm::isa<ShapedType>(t); }));
68 
69  // TODO: atm all operands go through getElementTypeOrSelf,
70  // reconsider when we have evidence we need to.
71  SmallVector<Type, 8> argTypes;
73  for (auto containers : {inputTypes, outputTypes}) {
74  for (auto t : containers) {
75  argTypes.push_back(getElementTypeOrSelf(t));
76 
77  // TODO: Pass in a proper location here.
78  argLocs.push_back(opBuilder.getUnknownLoc());
79  }
80  }
81 
82  // RAII.
83  OpBuilder::InsertionGuard guard(opBuilder);
84  Block *body =
85  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
86 
87  opBuilder.setInsertionPointToStart(body);
88  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
89  regionBuilder(b, *body, attrs);
90 
91  // indexing_maps is an auto-generated method.
92 
93  // iterator_types is an auto-generated method.
94 }
95 
96 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
97 /// The result types are derived automatically if `resultTensorTypes` is none.
98 /// The body of the operation is filled using `regionBuilder`. All ods-gen
99 /// created structured operations use the method to implement their builders.
101  std::optional<TypeRange> resultTensorTypes,
102  ValueRange inputs, ValueRange outputs,
103  ArrayRef<NamedAttribute> attributes,
104  RegionBuilderFn regionBuilder) {
105  // Derive the result types if needed.
106  SmallVector<Type> derivedResultTypes =
107  resultTensorTypes.value_or(TypeRange());
108  if (!resultTensorTypes)
109  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
110  [](Type type) { return llvm::isa<RankedTensorType>(type); });
111 
112  state.addOperands(inputs);
113  state.addOperands(outputs);
114  state.addTypes(derivedResultTypes);
115  state.addAttributes(attributes);
116  state.addAttribute(
117  "operand_segment_sizes",
118  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
119  static_cast<int32_t>(outputs.size())}));
120 
121  // Create and fill the region of the structured operation.
122  Region &region = *state.addRegion();
123  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
124  state.attributes.getAttrs(), regionBuilder);
125 }
126 
127 /// Common parsing used for both named structured ops created by ods-gen and by
128 /// manually defined C++ ops. Does not handle regions.
129 static ParseResult
131  SmallVectorImpl<Type> &inputTypes,
132  SmallVectorImpl<Type> &outputTypes,
133  bool addOperandSegmentSizes = true) {
134  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
136  outputsOperands;
137 
138  if (succeeded(parser.parseOptionalLess())) {
139  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
140  return failure();
141  }
142  attrsLoc = parser.getCurrentLocation();
143  if (parser.parseOptionalAttrDict(result.attributes))
144  return failure();
145 
146  if (succeeded(parser.parseOptionalKeyword("ins"))) {
147  if (parser.parseLParen())
148  return failure();
149 
150  inputsOperandsLoc = parser.getCurrentLocation();
151  if (parser.parseOperandList(inputsOperands) ||
152  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
153  return failure();
154  }
155 
156  if (succeeded(parser.parseOptionalKeyword("outs"))) {
157  outputsOperandsLoc = parser.getCurrentLocation();
158  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
159  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
160  return failure();
161  }
162 
163  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
164  result.operands) ||
165  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
166  result.operands))
167  return failure();
168 
169  if (addOperandSegmentSizes) {
170  // This is a bit complex because we're trying to be backward compatible with
171  // operation syntax that mix the inherent attributes and the discardable ones
172  // in the same dictionary.
173  // If the properties are used, we append the operand_segment_sizes there directly.
174  // Otherwise we append it to the discardable attributes dictionary where it is
175  // handled by the generic Operation::create(...) method.
176  if (result.propertiesAttr) {
177  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
178  attrs.append("operand_segment_sizes",
180  {static_cast<int32_t>(inputsOperands.size()),
181  static_cast<int32_t>(outputsOperands.size())}));
182  result.propertiesAttr = attrs.getDictionary(parser.getContext());
183  } else {
184  result.addAttribute("operand_segment_sizes",
186  {static_cast<int32_t>(inputsOperands.size()),
187  static_cast<int32_t>(outputsOperands.size())}));
188  }
189  }
190  if (!result.propertiesAttr) {
191  std::optional<RegisteredOperationName> info =
192  result.name.getRegisteredInfo();
193  if (info) {
194  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
195  return parser.emitError(attrsLoc)
196  << "'" << result.name.getStringRef() << "' op ";
197  })))
198  return failure();
199  }
200  }
201  return success();
202 }
203 
205  ValueRange outputs) {
206  if (!inputs.empty())
207  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
208  if (!outputs.empty())
209  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
210 }
211 
//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

217  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
218  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
219  RegionBuilderFn regionBuilder) {
220  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
221  return parser.emitError(
222  parser.getCurrentLocation(),
223  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
224  "region expects {0} args, got {1}",
225  numRegionArgs, inputTypes.size() + outputTypes.size()));
226  }
227 
228  OpBuilder opBuilder(parser.getContext());
229  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
230  regionBuilder);
231  return success();
232 }
233 
234 static ParseResult
236  SmallVectorImpl<Type> &resultTypes) {
237  if (parser.parseOptionalArrowTypeList(resultTypes))
238  return failure();
239  return success();
240 }
241 
243  OperationState &result,
244  unsigned numRegionArgs,
245  RegionBuilderFn regionBuilder) {
246  // TODO: Enable when ods-gen supports captures.
247  SmallVector<Type, 1> inputTypes, outputTypes;
248  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
249  return failure();
250 
251  // TODO: consider merging results parsing into region parsing.
252  // Need to wait for declarative assembly resolution to decide.
253  SmallVector<Type, 1> outputTensorsTypes;
254  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
255  return failure();
256  result.addTypes(outputTensorsTypes);
257 
258  std::unique_ptr<Region> region = std::make_unique<Region>();
259  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
260  outputTypes, result.attributes.getAttrs(),
261  regionBuilder))
262  return failure();
263  result.addRegion(std::move(region));
264 
265  return success();
266 }
267 
269  TypeRange resultTypes) {
270  if (resultTypes.empty())
271  return;
272  p.printOptionalArrowTypeList(resultTypes);
273 }
274 
276  ValueRange inputs, ValueRange outputs) {
278  op->getAttrs(),
279  /*elidedAttrs=*/{"operand_segment_sizes",
280  // See generated code in
281  // LinalgNamedStructuredOps.yamlgen.cpp.inc
282  "linalg.memoized_indexing_maps"});
283 
284  // Printing is shared with generic ops, except for the region and
285  // attributes.
286  printCommonStructuredOpParts(p, inputs, outputs);
287 
288  // Results printing.
290 
291  // Region is elided.
292 }
293 
//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helper build the unary, binary, and type conversion functions defined by the
// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
// class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the mean-time, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

317 namespace {
318 
319 class RegionBuilderHelper {
320 public:
321  RegionBuilderHelper(MLIRContext *context, Block &block)
322  : context(context), block(block) {}
323 
324  // Build the unary functions defined by OpDSL.
325  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
326  if (!isFloatingPoint(arg))
327  llvm_unreachable("unsupported non numeric type");
328  OpBuilder builder = getBuilder();
329  switch (unaryFn) {
330  case UnaryFn::exp:
331  return builder.create<math::ExpOp>(arg.getLoc(), arg);
332  case UnaryFn::log:
333  return builder.create<math::LogOp>(arg.getLoc(), arg);
334  case UnaryFn::abs:
335  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
336  case UnaryFn::ceil:
337  return builder.create<math::CeilOp>(arg.getLoc(), arg);
338  case UnaryFn::floor:
339  return builder.create<math::FloorOp>(arg.getLoc(), arg);
340  case UnaryFn::negf:
341  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
342  }
343  llvm_unreachable("unsupported unary function");
344  }
345 
346  // Build the binary functions defined by OpDSL.
347  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
348  bool allComplex = isComplex(arg0) && isComplex(arg1);
349  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
350  bool allInteger = isInteger(arg0) && isInteger(arg1);
351  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
352  arg1.getType().getIntOrFloatBitWidth() == 1;
353  if (!allComplex && !allFloatingPoint && !allInteger)
354  llvm_unreachable("unsupported non numeric type");
355  OpBuilder builder = getBuilder();
356  switch (binaryFn) {
357  case BinaryFn::add:
358  if (allComplex)
359  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
360  if (allFloatingPoint)
361  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
362  if (allBool)
363  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
364  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
365  case BinaryFn::sub:
366  if (allComplex)
367  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
368  if (allFloatingPoint)
369  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
370  if (allBool)
371  llvm_unreachable("unsupported operation: sub with bools");
372  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
373  case BinaryFn::mul:
374  if (allComplex)
375  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
376  if (allFloatingPoint)
377  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
378  if (allBool)
379  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
380  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
381  case BinaryFn::max_signed:
382  assert(!allComplex);
383  if (allFloatingPoint)
384  return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
385  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
386  case BinaryFn::min_signed:
387  assert(!allComplex);
388  if (allFloatingPoint)
389  return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
390  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
391  case BinaryFn::max_unsigned:
392  assert(!allComplex);
393  if (allFloatingPoint)
394  return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
395  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
396  case BinaryFn::min_unsigned:
397  assert(!allComplex);
398  if (allFloatingPoint)
399  return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
400  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
401  }
402  llvm_unreachable("unsupported binary function");
403  }
404 
405  // Build the type functions defined by OpDSL.
406  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
407  switch (typeFn) {
408  case TypeFn::cast_signed:
409  return cast(toType, operand, false);
410  case TypeFn::cast_unsigned:
411  return cast(toType, operand, true);
412  }
413  llvm_unreachable("unsupported type conversion function");
414  }
415 
416  void yieldOutputs(ValueRange values) {
417  OpBuilder builder = getBuilder();
418  Location loc = builder.getUnknownLoc();
419  builder.create<YieldOp>(loc, values);
420  }
421 
422  Value constant(const std::string &value) {
423  OpBuilder builder = getBuilder();
424  Location loc = builder.getUnknownLoc();
425  Attribute valueAttr = parseAttribute(value, builder.getContext());
426  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
427  }
428 
429  Value index(int64_t dim) {
430  OpBuilder builder = getBuilder();
431  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
432  }
433 
434  Type getIntegerType(unsigned width) {
435  return IntegerType::get(context, width);
436  }
437 
438  Type getFloat32Type() { return Float32Type::get(context); }
439  Type getFloat64Type() { return Float64Type::get(context); }
440 
441 private:
442  // Generates operations to cast the given operand to a specified type.
443  // If the cast cannot be performed, a warning will be issued and the
444  // operand returned as-is (which will presumably yield a verification
445  // issue downstream).
446  Value cast(Type toType, Value operand, bool isUnsignedCast) {
447  OpBuilder builder = getBuilder();
448  auto loc = operand.getLoc();
449  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
450  }
451 
452  bool isComplex(Value value) {
453  return llvm::isa<ComplexType>(value.getType());
454  }
455  bool isFloatingPoint(Value value) {
456  return llvm::isa<FloatType>(value.getType());
457  }
458  bool isInteger(Value value) {
459  return llvm::isa<IntegerType>(value.getType());
460  }
461 
462  OpBuilder getBuilder() {
463  OpBuilder builder(context);
464  builder.setInsertionPointToEnd(&block);
465  return builder;
466  }
467 
468  MLIRContext *context;
469  Block &block;
470 };
471 
472 } // namespace
473 
//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

478 namespace {
479 
480 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
481 ///
482 /// For such op chains, we can create new linalg.fill ops with the result
483 /// type of the tensor.expand/collapse_shape op.
484 template <typename TensorReshapeOp>
485 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
487  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
488  PatternRewriter &rewriter) const override {
489  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
490  if (!oldFill)
491  return failure();
492 
493  Location loc = oldFill.getLoc();
494  auto newInit = rewriter.create<TensorReshapeOp>(
495  loc, reshapeOp.getResultType(), oldFill.output(),
496  reshapeOp.getReassociation());
497  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
498  ValueRange{newInit});
499 
500  return success();
501  }
502 };
503 
504 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
505 /// filling value are the same.
506 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
508 
509  LogicalResult matchAndRewrite(tensor::PadOp padOp,
510  PatternRewriter &rewriter) const override {
511  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
512  if (!fillOp)
513  return failure();
514 
515  // We can only fold if the padding value is the same as the original
516  // filling value.
517  Value padValue = padOp.getConstantPaddingValue();
518  if (!padValue || fillOp.value() != padValue)
519  return failure();
520 
521  ReifiedRankedShapedTypeDims reifiedShape;
522  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
523  return rewriter.notifyMatchFailure(
524  padOp, "failed to reify tensor.pad op result shape");
525 
526  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
527  padOp.getLoc(), reifiedShape.front(),
528  padOp.getResultType().getElementType());
529  Value replacement =
530  rewriter
531  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
532  ValueRange{emptyTensor})
533  .getResult(0);
534  if (replacement.getType() != padOp.getResultType()) {
535  replacement = rewriter.create<tensor::CastOp>(
536  fillOp.getLoc(), padOp.getResultType(), replacement);
537  }
538  rewriter.replaceOp(padOp, replacement);
539  return success();
540  }
541 };
542 
543 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
544 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
545 /// filling value are the same.
546 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
548 
549  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
550  PatternRewriter &rewriter) const override {
551  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
552  if (!srcPadOp)
553  return failure();
554 
555  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
556  return failure();
557 
558  // Walk back the tensor.insert_slice chain and find the first destination
559  // value at the start of the chain.
560  Value firstDest = insertOp.getDest();
561  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
562  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
563  return failure();
564 
565  // Make sure the range of values accessed are disjoint. Without this, we
566  // cannot fold tensor.pad away.
567  bool disjoint = false;
568  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
569  // If the dimension has dynamic offset/size, we cannot guarantee
570  // disjoint. So just skip it.
571  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
572  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
573  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
574  continue;
575 
576  // Get the range start and end, inclusively for both.
577  int64_t prevStart = prevOp.getStaticOffset(i);
578  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
579  prevOp.getStaticStride(i);
580  int64_t nextStart = insertOp.getStaticOffset(i);
581  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
582  insertOp.getStaticStride(i);
583  if (prevEnd < nextStart || nextEnd < prevStart) {
584  disjoint = true;
585  break;
586  }
587  }
588 
589  if (!disjoint)
590  break;
591  firstDest = prevOp.getDest();
592  }
593 
594  // Check whether the first destination is a fill op. For overlapped cases,
595  // this also cannot be true.
596  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
597  if (!dstFillOp)
598  return failure();
599 
600  // We can only fold if the padding value is the same as the original
601  // filling value.
602  Value padValue = srcPadOp.getConstantPaddingValue();
603  if (!padValue || dstFillOp.value() != padValue)
604  return failure();
605 
606  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
607  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
608 
609  Location loc = insertOp.getLoc();
610  MLIRContext *context = getContext();
611 
612  AffineExpr sym0, sym1;
613  bindSymbols(context, sym0, sym1);
614  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
615 
616  // Calculate the new offsets for the insert. It should be the old offsets
617  // plus low padding sizes.
618  SmallVector<OpFoldResult, 4> newOffsets;
619  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
620  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
621  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
622  }
623 
625  for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
626  newSizes.push_back(
627  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
628  .getResult());
629  }
630 
631  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
632  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
633  newSizes, insertOp.getMixedStrides());
634  return success();
635  }
636 };
637 
638 } // namespace
639 
640 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
641  MLIRContext *context) {
642  results
643  .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
644  FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
645  FoldInsertPadIntoFill>(context);
646 }
647 
//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

652 static void buildGenericRegion(
653  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
654  ValueRange outputs,
655  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
656  SmallVector<Type, 4> blockArgTypes;
657  SmallVector<Location, 4> blockArgLocs;
658  for (ValueRange container : {inputs, outputs}) {
659  for (Value v : container) {
660  blockArgTypes.push_back(getElementTypeOrSelf(v));
661  blockArgLocs.push_back(v.getLoc());
662  }
663  }
664 
665  OpBuilder::InsertionGuard guard(builder);
666  Block *bodyBlock =
667  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
668  bodyBuild(builder, loc, bodyBlock->getArguments());
669 }
670 
671 void GenericOp::getAsmBlockArgumentNames(Region &region,
672  OpAsmSetValueNameFn setNameFn) {
673  for (Value v : getRegionInputArgs())
674  setNameFn(v, "in");
675  for (Value v : getRegionOutputArgs())
676  setNameFn(v, "out");
677 }
678 
679 void GenericOp::build(
680  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
681  ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
682  ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
683  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
684  ArrayRef<NamedAttribute> attributes) {
685  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
686  iteratorTypes, doc, libraryCall);
687  result.addAttributes(attributes);
688  if (bodyBuild)
689  buildGenericRegion(builder, result.location, *result.regions.front(),
690  inputs, outputs, bodyBuild);
691 }
692 
693 void GenericOp::build(
694  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
695  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
696  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
697  StringRef libraryCall,
698  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
699  ArrayRef<NamedAttribute> attributes) {
700  build(builder, result, resultTensorTypes, inputs, outputs,
701  builder.getAffineMapArrayAttr(indexingMaps),
702  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
703  iteratorTypes,
704  [&](utils::IteratorType iter) -> mlir::Attribute {
705  return IteratorTypeAttr::get(builder.getContext(), iter);
706  }))),
707  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
708  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
709  bodyBuild, attributes);
710 }
711 
712 void GenericOp::build(
713  OpBuilder &builder, OperationState &result, ValueRange inputs,
714  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
715  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
716  StringRef libraryCall,
717  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
718  ArrayRef<NamedAttribute> attributes) {
719  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
720  iteratorTypes, doc, libraryCall, bodyBuild, attributes);
721 }
722 
723 void GenericOp::build(
724  OpBuilder &builder, OperationState &result, ValueRange inputs,
725  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
726  ArrayRef<utils::IteratorType> iteratorTypes,
727  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
728  ArrayRef<NamedAttribute> attributes) {
729  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
730  /*doc=*/"",
731  /*libraryCall=*/"", bodyBuild, attributes);
732 }
733 
734 void GenericOp::build(
735  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
736  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
737  ArrayRef<utils::IteratorType> iteratorTypes,
738  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
739  ArrayRef<NamedAttribute> attributes) {
740  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
741  iteratorTypes,
742  /*doc=*/"",
743  /*libraryCall=*/"", bodyBuild, attributes);
744 }
745 
747  p << " ";
748 
749  // Print extra attributes.
750  auto genericAttrNames = linalgTraitAttrNames();
751 
752  llvm::StringSet<> genericAttrNamesSet;
753  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
754  SmallVector<NamedAttribute, 8> genericAttrs;
755  for (auto attr : (*this)->getAttrs()) {
756  if (attr.getName() == getIteratorTypesAttrName()) {
757  auto iteratorTypes =
758  llvm::cast<ArrayAttr>(attr.getValue())
759  .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
760  // Convert IteratorType enums into the string representation. This is
761  // needed, because tests still use the old format when 'iterator_types'
762  // attribute is represented as an array of strings.
763  // TODO: Remove this conversion once tests are fixed.
764  SmallVector<Attribute> iteratorTypeNames =
765  llvm::to_vector(llvm::map_range(
766  iteratorTypes, [&](utils::IteratorType t) -> Attribute {
767  return StringAttr::get(getContext(), stringifyIteratorType(t));
768  }));
769 
770  genericAttrs.emplace_back(
771  getIteratorTypesAttrName(),
772  ArrayAttr::get(getContext(), iteratorTypeNames));
773  } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
774  genericAttrs.push_back(attr);
775  }
776  }
777  if (!genericAttrs.empty()) {
778  auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
779  p << genericDictAttr;
780  }
781 
782  // Printing is shared with named ops, except for the region and attributes
783  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
784  SmallVector<Value>(getDpsInitOperands()));
785 
786  genericAttrNames.push_back("operand_segment_sizes");
787  genericAttrNamesSet.insert(genericAttrNames.back());
788 
789  bool hasExtraAttrs = false;
790  for (NamedAttribute n : (*this)->getAttrs()) {
791  if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
792  break;
793  }
794  if (hasExtraAttrs) {
795  p << " attrs = ";
796  p.printOptionalAttrDict((*this)->getAttrs(),
797  /*elidedAttrs=*/genericAttrNames);
798  }
799 
800  // Print region.
801  if (!getRegion().empty()) {
802  p << ' ';
803  p.printRegion(getRegion());
804  }
805 
806  // Print results.
807  printNamedStructuredOpResults(p, getResultTensors().getTypes());
808 }
809 
/// Parses the custom assembly of linalg.generic: a leading dictionary holding
/// the core structured-op attributes, ins/outs operands, an optional `attrs`
/// dictionary, the body region, and trailing result tensor types.
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  // Splice the parsed dictionary's entries directly into the op's attributes.
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  // Each string must symbolize to a known IteratorType enum value.
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  // Overwrite the string form with the enum-based attribute.
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
875 
878  &effects,
879  ValueRange results, const OpOperandVector &inputOperands,
880  const OpOperandVector &outputOperands) {
881  for (auto *operand : inputOperands) {
882  if (!llvm::isa<MemRefType>(operand->get().getType()))
883  continue;
884  effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
886  }
887  for (auto *operand : outputOperands) {
888  if (!llvm::isa<MemRefType>(operand->get().getType()))
889  continue;
890  effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
892  effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
894  }
895 }
896 
897 void GenericOp::getEffects(
899  &effects) {
900  getGenericEffectsImpl(effects, getOperation()->getResults(),
901  getDpsInputOperands(), getDpsInitOperands());
902 }
903 
905 
906 namespace {
907 
908 /// Remove generic operations (on tensors) that are just copying
909 /// the values from inputs to the results. Requirements are
910 /// 1) All iterator types are parallel
911 /// 2) The body contains just a yield operation with the yielded values being
912 /// the arguments corresponding to the operands.
913 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
915 
916  LogicalResult matchAndRewrite(GenericOp genericOp,
917  PatternRewriter &rewriter) const override {
918  // Check all indexing maps are identity.
919  if (llvm::any_of(genericOp.getIndexingMapsArray(),
920  [](AffineMap map) { return !map.isIdentity(); }))
921  return failure();
922 
923  // Check that the body of the linalg operation is just a linalg.yield
924  // operation.
925  Block &body = genericOp.getRegion().front();
926  if (!llvm::hasSingleElement(body))
927  return failure();
928  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
929  if (!yieldOp)
930  return failure();
931 
932  // In the buffer case, we need to check exact buffer equality.
933  if (genericOp.hasBufferSemantics()) {
934  if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
935  genericOp.getDpsInputOperand(0)->get() ==
936  genericOp.getDpsInitOperand(0)->get()) {
937  rewriter.eraseOp(genericOp);
938  return success();
939  }
940  return failure();
941  }
942 
943  // Mixed semantics is not supported yet.
944  if (!genericOp.hasTensorSemantics())
945  return failure();
946 
947  // Get the argument number of the returned values. That is the operand
948  // number to use for replacing uses of this operation.
949  SmallVector<Value> returnedArgs;
950  for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
951  auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
952  if (!yieldArg || yieldArg.getOwner() != &body)
953  return failure();
954  unsigned argumentNumber = yieldArg.getArgNumber();
955  Value returnedArg = genericOp->getOperand(argumentNumber);
956  Type resultType = genericOp->getResult(yieldVal.index()).getType();
957  // The input can have a different type than the result, e.g. a dynamic
958  // input dimension can be turned into a static output dimension.
959  Type returnType = returnedArg.getType();
960  if (returnType != resultType) {
961  // Distinguish between sparse conversion or dense tensor casting.
962  // TODO: unify the two ops?
963  if (sparse_tensor::getSparseTensorEncoding(returnType) ||
965  returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
966  genericOp.getLoc(), resultType, returnedArg);
967  else {
968  if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
969  resultType))
970  return failure();
971  returnedArg = rewriter.create<tensor::CastOp>(
972  genericOp.getLoc(), resultType, returnedArg);
973  }
974  }
975  returnedArgs.push_back(returnedArg);
976  }
977 
978  if (returnedArgs.size() != genericOp->getNumResults())
979  return failure();
980  rewriter.replaceOp(genericOp, returnedArgs);
981  return success();
982  }
983 };
984 
985 } // namespace
986 
/// Registers canonicalization patterns for linalg.generic.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityGenericOp>(context);
}
991 
/// Folds memref.cast producers into this op's operands when possible.
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
995 
996 //===----------------------------------------------------------------------===//
997 // MapOp
998 //===----------------------------------------------------------------------===//
999 
1001  OpAsmParser &parser, OperationState &result,
1002  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1003  nullptr) {
1004  // Parse `ins` and `outs`.
1005  SmallVector<Type, 4> inputTypes, outputTypes;
1006  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1007  /*addOperandSegmentSizes=*/false))
1008  return failure();
1009 
1010  // Add result types.
1011  for (Type outputType : outputTypes) {
1012  if (llvm::isa<RankedTensorType>(outputType))
1013  result.addTypes(outputType);
1014  }
1015 
1016  // Parse required attributes.
1017  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1018  return failure();
1019 
1020  // Parse optional attributes.
1021  if (parser.parseOptionalAttrDict(result.attributes))
1022  return failure();
1023  return success();
1024 }
1025 
/// Names the mapper region's block arguments "in" for readable IR dumps.
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}
1031 
/// Suggests "mapped" as the printed name for this op's result.
void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}
1036 
/// Builds a MapOp from inputs and a single init, optionally populating the
/// body via `bodyBuild`. A tensor result is added only when `init` is a
/// ranked tensor (memref inits produce no results).
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // Populate the region with block arguments for the inputs only; the init
  // does not contribute a block argument for map.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1053 
/// Builds the op's body region containing a single payload operation (parsed
/// from the short form) followed by a yield of its results. Block arguments
/// are created from the element types of `operands`. With `initFirst`, the
/// init block argument (the last one) is passed as the payload's first
/// operand.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  // One scalar block argument per shaped operand.
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  // The payload's result type is the element type of the (last) init operand.
  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}
1089 
1090 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1091  std::optional<OperationName> payloadOpName;
1092  NamedAttrList payloadOpAttrs;
1093  if (succeeded(parser.parseOptionalLBrace())) {
1094  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1095  if (failed(operationName))
1096  return failure();
1097  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1098  return failure();
1099  payloadOpName = operationName.value();
1100  if (parser.parseRBrace())
1101  return failure();
1102  }
1103 
1104  if (parseDstStyleOp(parser, result))
1105  return failure();
1106 
1107  if (payloadOpName.has_value()) {
1108  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1109  ArrayRef(result.operands).drop_back());
1110  } else {
1112  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1113  /*allowType=*/true, /*allowAttrs=*/true)) {
1114  return failure();
1115  }
1116  Region *body = result.addRegion();
1117  if (parser.parseRegion(*body, regionArgs))
1118  return failure();
1119  }
1120  return success();
1121 }
1122 
1123 // Retrieve the operation from the body, if it is the only one (except
1124 // yield) and if it gets the same amount of arguments as the body does.
1125 // If initFirst flag is enabled, we check that init takes the first position in
1126 // operands of payload.
1127 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1128  if (body->getOperations().size() != 2)
1129  return nullptr;
1130  Operation &payload = body->getOperations().front();
1131  assert(isa<YieldOp>(body->getOperations().back()));
1132 
1133  if (payload.getNumOperands() == 0 ||
1134  payload.getNumOperands() != body->getNumArguments())
1135  return nullptr;
1136  if (initFirst) {
1137  // check init
1138  if (payload.getOperands().back() != body->getArgument(0))
1139  return nullptr;
1140  // check rest
1141  for (const auto &[operand, bbArg] :
1142  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1143  if (bbArg != operand)
1144  return nullptr;
1145  }
1146  } else {
1147  for (const auto &[operand, bbArg] :
1148  llvm::zip(payload.getOperands(), body->getArguments())) {
1149  if (bbArg != operand)
1150  return nullptr;
1151  }
1152  }
1153  return &payload;
1154 }
1155 
1156 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1157  SmallVector<StringRef> elidedAttrs;
1158  std::string attrToElide;
1159  p << " { " << payloadOp->getName().getStringRef();
1160  for (const auto &attr : payloadOp->getAttrs()) {
1161  auto fastAttr =
1162  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1163  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1164  attrToElide = attr.getName().str();
1165  elidedAttrs.push_back(attrToElide);
1166  break;
1167  }
1168  }
1169  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1170  p << " }";
1171 }
1172 
/// Prints linalg.map, preferring the short payload form when the body is a
/// single recognizable payload op; otherwise prints the full region.
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
                               SmallVector<Value>(getDpsInitOperands()));
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1197 
1199  auto *bodyBlock = getBody();
1200  auto blockArgs = bodyBlock->getArguments();
1201 
1202  // Checks if the number of `inputs` match the arity of the `mapper` region.
1203  if (getInputs().size() != blockArgs.size())
1204  return emitOpError() << "expects number of operands to match the arity of "
1205  "mapper, but got: "
1206  << getInputs().size() << " and " << blockArgs.size();
1207 
1208  // The parameters of mapper should all match the element type of inputs.
1209  for (const auto &[bbArgType, inputArg] :
1210  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1211  auto inputElemType =
1212  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1213  if (bbArgType != inputElemType) {
1214  return emitOpError() << "expected element type of input " << inputElemType
1215  << " to match bbArg type " << bbArgType;
1216  }
1217  }
1218 
1219  // The shape of each input must match the shape of the output.
1220  auto outputShape = getInit().getType().getShape();
1221  for (Type inputArgType : TypeRange{getInputs()}) {
1222  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1223  if (inputElemShape != outputShape) {
1224  return emitOpError() << "expected shape of input (" << inputElemShape
1225  << ") to match shape of output (" << outputShape
1226  << ")";
1227  }
1228  }
1229 
1230  return success();
1231 }
1232 
/// linalg.map is elementwise: every loop is parallel.
SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
1237 
1238 ArrayAttr MapOp::getIndexingMaps() {
1239  Builder builder(getContext());
1240  int64_t rank = getInit().getType().getRank();
1241  int64_t numIndexingMaps = getOperands().size();
1243  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1244 }
1245 
1246 void MapOp::getEffects(
1248  &effects) {
1249  getGenericEffectsImpl(effects, getOperation()->getResults(),
1250  getDpsInputOperands(), getDpsInitOperands());
1251 }
1252 
1253 //===----------------------------------------------------------------------===//
1254 // ReduceOp
1255 //===----------------------------------------------------------------------===//
1256 
/// Names input block arguments "in" and accumulator arguments "init".
void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}
1264 
/// Suggests "reduced" as the printed name for this op's result.
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}
1270 
/// Builds a ReduceOp over `dimensions`, optionally populating the combiner
/// region via `bodyBuild`. A tensor result is added for each ranked-tensor
/// init (memref inits produce no results).
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  // Combiner block arguments: one per input, then one per init.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
1290 
/// All loops are parallel except those listed in `dimensions`, which are
/// reductions.
SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}
1300 
/// Inputs use the identity map over the input rank; inits use the identity
/// map with the reduced dimensions dropped.
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1314 
1315 void ReduceOp::getEffects(
1317  &effects) {
1318  getGenericEffectsImpl(effects, getOperation()->getResults(),
1319  getDpsInputOperands(), getDpsInitOperands());
1320 }
1321 
1323  NamedAttrList &attributes,
1324  StringRef attributeName) {
1325  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1326  return failure();
1327 
1328  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1329  return success();
1330 }
1331 
1332 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1333  std::optional<OperationName> payloadOpName;
1334  NamedAttrList payloadOpAttrs;
1335  if (succeeded(parser.parseOptionalLBrace())) {
1336  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1337  if (failed(operationName))
1338  return failure();
1339  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1340  return failure();
1341  payloadOpName = operationName.value();
1342  if (parser.parseRBrace())
1343  return failure();
1344  }
1345 
1346  if (parseDstStyleOp(
1347  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1348  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1349  }))
1350  return failure();
1351 
1352  if (payloadOpName.has_value()) {
1353  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1354  ArrayRef(result.operands), /*initFirst=*/true);
1355  } else {
1357  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1358  /*allowType=*/true, /*allowAttrs=*/true)) {
1359  return failure();
1360  }
1361 
1362  Region *body = result.addRegion();
1363  if (parser.parseRegion(*body, regionArgs))
1364  return failure();
1365  }
1366 
1367  return success();
1368 }
1369 
/// Prints `attributeName = [v0, v1, ...]` for a dense i64 array attribute.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1374 
/// Prints linalg.reduce: short payload form when the combiner is a single
/// recognizable op (init first), ins/outs, the `dimensions` attribute, and
/// the combiner region otherwise.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
                               SmallVector<Value>(getDpsInitOperands()));
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` is printed explicitly above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1399 
1401  ArrayRef<int64_t> dimensionsRef = getDimensions();
1402 
1403  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1404  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1405  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1406  return emitOpError() << "expects all inputs to have the same shapes. "
1407  "Shape at input-index "
1408  << i
1409  << " is not equal to the shape at input-index 0.";
1410  }
1411  }
1412  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1413  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1414  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1415  return emitOpError() << "expects all outputs to have the same shapes. "
1416  "Shape at output-index "
1417  << i
1418  << " is not equal to the shape at output-index 0.";
1419  }
1420  }
1421  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1422  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1423 
1424  DenseSet<int64_t> dimensionsToReduce;
1425  for (int64_t dimension : dimensionsRef) {
1426  if (dimension < 0 || dimension >= inputType.getRank()) {
1427  return emitOpError()
1428  << "dimensions for reduction should be in the range [0, "
1429  << inputType.getRank() - 1 << "].";
1430  }
1431  dimensionsToReduce.insert(dimension);
1432  }
1433 
1434  auto inputDims = inputType.getShape();
1435  auto initDims = initType.getShape();
1436 
1437  // Input dimensions that will be left after the reduction.
1438  SmallVector<int64_t> reducedInputDims;
1439  for (const auto &en : llvm::enumerate(inputDims)) {
1440  if (!dimensionsToReduce.count(en.index()))
1441  reducedInputDims.push_back(en.value());
1442  }
1443 
1444  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1445  return emitOpError() << "number of dimensions after reduction "
1446  << reducedInputDims.size()
1447  << " doesn't match the init rank "
1448  << initType.getRank();
1449  }
1450 
1451  if (reducedInputDims != initDims)
1452  return emitOpError() << "init dimensions [" << initDims
1453  << "] doesn't match input dimensions after reduction ["
1454  << reducedInputDims << "]";
1455 
1456  Block *block = getBody();
1457  if (block->getNumArguments() != this->getNumOperands())
1458  return emitOpError()
1459  << "mismatching number of operands and block arguments";
1460 
1461  // Check that the first block arguments match the element type of the inputs.
1462  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1463  Type inputElementType =
1464  llvm::cast<ShapedType>(input.getType()).getElementType();
1465  if (inputElementType != bbArg.getType())
1466  return emitOpError()
1467  << "input element type " << inputElementType
1468  << " does not match corresponding block argument type "
1469  << bbArg.getType();
1470  }
1471 
1472  // Check that the last block arguments match the element type of the outputs.
1473  for (auto [output, bbArg] :
1474  llvm::zip(getDpsInitOperands(),
1475  block->getArguments().take_back(getNumDpsInits()))) {
1476  auto outputElementType =
1477  llvm::cast<ShapedType>(output->get().getType()).getElementType();
1478  if (outputElementType != bbArg.getType())
1479  return emitOpError()
1480  << "output element type " << outputElementType
1481  << " does not match corresponding block argument type "
1482  << bbArg.getType();
1483  }
1484  return success();
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // TransposeOp
1489 //===----------------------------------------------------------------------===//
1490 
/// Populates `region` with a body that simply yields its first argument —
/// the canonical identity body used by transpose and broadcast.
static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       b.create<linalg::YieldOp>(loc, args[0]);
                     });
}
1499 
/// Builds a TransposeOp with an attribute permutation; the identity region is
/// always created. A tensor result is added only for ranked-tensor inits.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}
1517 
/// Convenience overload taking the permutation as a plain integer array.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}
1525 
/// Parses linalg.transpose (ins/outs plus a required `permutation`) and
/// synthesizes the implicit identity region.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  // The region is not spelled in the assembly; build the identity body here.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
1539 
/// Suggests "transposed" as the printed name for this op's result.
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}
1545 
1547  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
1548  SmallVector<Value>(getDpsInitOperands()));
1549  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1550  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1551 }
1552 
1554  ArrayRef<int64_t> permutationRef = getPermutation();
1555 
1556  if (!isPermutationVector(permutationRef))
1557  return emitOpError("permutation is not valid");
1558 
1559  auto inputType = getInput().getType();
1560  auto initType = getInit().getType();
1561 
1562  int64_t rank = inputType.getRank();
1563 
1564  if (rank != initType.getRank())
1565  return emitOpError() << "input rank " << rank
1566  << " does not match init rank " << initType.getRank();
1567 
1568  if (rank != static_cast<int64_t>(permutationRef.size()))
1569  return emitOpError() << "size of permutation " << permutationRef.size()
1570  << " does not match the argument rank " << rank;
1571 
1572  auto inputDims = inputType.getShape();
1573  auto initDims = initType.getShape();
1574 
1575  for (int64_t i = 0; i < rank; ++i) {
1576  int64_t inputDim = inputDims[permutationRef[i]];
1577  int64_t initDim = initDims[i];
1578 
1579  if (inputDim != initDim) {
1580  return emitOpError() << "dim(result, " << i << ") = " << initDim
1581  << " doesn't match dim(input, permutation[" << i
1582  << "]) = " << inputDim;
1583  }
1584  }
1585 
1586  return success();
1587 }
1588 
/// linalg.transpose is a pure data movement op: every loop is parallel.
SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
1593 
1594 ArrayAttr TransposeOp::getIndexingMaps() {
1595  Builder builder(getContext());
1596  int64_t rank = getInit().getType().getRank();
1597  return builder.getAffineMapArrayAttr(
1598  {builder.getMultiDimIdentityMap(rank),
1600  llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
1601 }
1602 
1603 void TransposeOp::getEffects(
1605  &effects) {
1606  getGenericEffectsImpl(effects, getOperation()->getResults(),
1607  getDpsInputOperands(), getDpsInitOperands());
1608 }
1609 
1610 //===----------------------------------------------------------------------===//
1611 // BroadcastOp
1612 //===----------------------------------------------------------------------===//
1613 
/// Builds a BroadcastOp with an attribute dimension list; the identity region
/// is always created. A tensor result is added only for ranked-tensor inits.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}
1631 
/// Convenience overload taking the dimension list as a plain integer array.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}
1639 
/// Parses linalg.broadcast (ins/outs plus a required `dimensions`) and
/// synthesizes the implicit identity region.
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  // The region is not spelled in the assembly; build the identity body here.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
1653 
/// Suggests "broadcasted" as the printed name for this op's result.
void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}
1659 
1661  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
1662  SmallVector<Value>(getDpsInitOperands()));
1663  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1664  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1665 }
1666 
1668  ArrayRef<int64_t> dimensionsRef = getDimensions();
1669 
1670  auto inputType = getInput().getType();
1671  auto initType = getInit().getType();
1672 
1673  int64_t inputRank = inputType.getRank();
1674  int64_t initRank = initType.getRank();
1675 
1676  auto inputShape = inputType.getShape();
1677  auto initShape = initType.getShape();
1678 
1679  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1680  return emitOpError() << "input rank plus added dimensions does not "
1681  "match init rank. input rank: "
1682  << inputRank
1683  << ", dimensions size: " << dimensionsRef.size()
1684  << ", init rank: " << initRank;
1685 
1686  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
1687  if (dim < 0 || dim >= initRank)
1688  return emitOpError() << "dimension " << idx
1689  << " is out of range. expected range: [0, "
1690  << initRank - 1 << "], got: " << dim;
1691  }
1692 
1693  // Mapping from input dims to init dims.
1694  SmallVector<int64_t> dimMap;
1695  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1696  if (!llvm::is_contained(dimensionsRef, dim))
1697  dimMap.push_back(dim);
1698  }
1699 
1700  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1701  // This dimensions is mapped from the input. Init and input dims should
1702  // match.
1703  if (inputShape[inputDimIdx] != initShape[initDimIdx])
1704  return emitOpError() << "input dim " << inputDimIdx
1705  << " should match init dim " << initDimIdx
1706  << ". input: " << inputShape[inputDimIdx]
1707  << ", init: " << initShape[initDimIdx];
1708  }
1709 
1710  return success();
1711 }
1712 
/// linalg.broadcast is a pure data movement op: every loop is parallel.
SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}
1717 
/// Indexing maps: the input drops the broadcasted dimensions from the
/// identity map; the init uses the full identity map.
ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}
1725 
1726 void BroadcastOp::getEffects(
1728  &effects) {
1729  getGenericEffectsImpl(effects, getOperation()->getResults(),
1730  getDpsInputOperands(), getDpsInitOperands());
1731 }
1732 
1733 //===----------------------------------------------------------------------===//
1734 // YieldOp
1735 //===----------------------------------------------------------------------===//
1736 
1738  if (getNumOperands() > 0)
1739  p << ' ' << getOperands();
1740  p.printOptionalAttrDict((*this)->getAttrs());
1741  if (getNumOperands() > 0)
1742  p << " : " << getOperandTypes();
1743 }
1744 
1745 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1747  SmallVector<Type, 2> types;
1748  SMLoc loc = parser.getCurrentLocation();
1749  return failure(parser.parseOperandList(opInfo) ||
1750  parser.parseOptionalAttrDict(result.attributes) ||
1751  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1752  parser.resolveOperands(opInfo, types, loc, result.operands));
1753 }
1754 
1755 // Check the operand number and types must match the element types of the
1756 // LinalgOp interface's shaped operands.
1757 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1758  if (op.getNumOperands() != linalgOp.getNumDpsInits())
1759  return op.emitOpError("expected number of yield values (")
1760  << linalgOp.getNumDpsInits()
1761  << ") to match the number of operands of the enclosing "
1762  << "LinalgOp (" << op.getNumOperands() << ")";
1763 
1764  for (OpOperand &opOperand : op->getOpOperands()) {
1765  OpOperand *outputOperand =
1766  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1767  Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1768  if (opOperand.get().getType() != elementType)
1769  return op.emitOpError("type of yield operand ")
1770  << (opOperand.getOperandNumber() + 1) << " ("
1771  << opOperand.get().getType() << ") doesn't match "
1772  << "the element type of the enclosing linalg.generic op ("
1773  << elementType << ")";
1774  }
1775  return success();
1776 }
1777 
1779  auto *parentOp = (*this)->getParentOp();
1780  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1781  return emitOpError("expected single non-empty parent region");
1782 
1783  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1784  return verifyYield(*this, linalgOp);
1785 
1786  return emitOpError("expected parent op with LinalgOp interface");
1787 }
1788 
1789 //===----------------------------------------------------------------------===//
1790 // IndexOp
1791 //===----------------------------------------------------------------------===//
1792 
1794  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1795  if (!linalgOp)
1796  return emitOpError("expected parent op with LinalgOp interface");
1797  if (linalgOp.getNumLoops() <= getDim())
1798  return emitOpError("expected dim (")
1799  << getDim() << ") to be lower than the number of loops ("
1800  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1801  return success();
1802 }
1803 
1804 /////// Operations corresponding to library calls defined with Tablegen ////////
1805 
1806 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1807 
1808 #define GET_OP_CLASSES
1809 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1810 
1811 #define GET_OP_CLASSES
1812 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1813 
1814 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
1815  unsigned rank,
1816  MLIRContext *context) {
1817  if (maybeMap)
1818  return *maybeMap;
1819  if (rank == 0)
1820  return AffineMap::get(context);
1821  return AffineMap::getMultiDimIdentityMap(rank, context);
1822 }
1823 
1825 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1826  MLIRContext *context) {
1828  res.reserve(num);
1829  for (unsigned i = 0; i < num; ++i)
1830  res.push_back(getAffineDimExpr(startIdx++, context));
1831  return res;
1832 }
1833 
1836  auto rangeA = llvm::make_range(a.begin(), a.end());
1837  auto rangeB = llvm::make_range(b.begin(), b.end());
1838  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1839  return llvm::to_vector<4>(concatRanges);
1840 }
1841 
1842 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1843  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
1844  ss << "view";
1845  for (auto size : memref.getShape())
1846  if (size < 0)
1847  ss << "sx";
1848  else
1849  ss << size << "x";
1850  if (failed(appendMangledType(ss, memref.getElementType())))
1851  return failure();
1852  if (auto as = memref.getMemorySpace()) {
1853  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
1854  ss << "as" << attr.getInt();
1855  else
1856  return failure();
1857  }
1858  return success();
1859  }
1860  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
1861  ss << "vector";
1862  llvm::interleave(
1863  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1864  if (failed(appendMangledType(ss, vec.getElementType())))
1865  return failure();
1866  return success();
1867  } else if (t.isSignlessIntOrIndexOrFloat()) {
1868  ss << t;
1869  return success();
1870  }
1871  return failure();
1872 }
1873 
1875  assert(isa<LinalgOp>(op));
1876  std::string name(op->getName().getStringRef().str());
1877  std::string fun = "";
1878  for (NamedAttribute kv : op->getAttrs()) {
1879  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
1880  fun = stringifyEnum(ufa.getValue()).str() + "_";
1881  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
1882  fun = stringifyEnum(bfa.getValue()).str() + "_";
1883  }
1884  }
1885  name.reserve(128);
1886  std::replace(name.begin(), name.end(), '.', '_');
1887  llvm::raw_string_ostream ss(name);
1888  ss << "_" << fun;
1889  for (Type t : op->getOperandTypes()) {
1890  if (failed(appendMangledType(ss, t)))
1891  return std::string();
1892  ss << "_";
1893  }
1894  std::string res = ss.str();
1895  res.pop_back();
1896  return res;
1897 }
1898 
1899 //===----------------------------------------------------------------------===//
1900 // Canonicalizers and Folders.
1901 //===----------------------------------------------------------------------===//
1902 
1903 namespace {
1904 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1906 
1907  LogicalResult matchAndRewrite(LinalgOp op,
1908  PatternRewriter &rewriter) const override {
1909  for (OpOperand &opOperand : op->getOpOperands()) {
1910  // Linalg "inputs" may be either tensor or memref type.
1911  // tensor<0xelt_type> is a convention that may not always mean
1912  // "0 iterations". Only erase in cases we see memref<...x0x...>.
1913  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
1914  if (!mt)
1915  continue;
1916  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
1917  rewriter.eraseOp(op);
1918  return success();
1919  }
1920  }
1921  return failure();
1922  }
1923 };
1924 
1925 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
1926 /// result that is more static than the linalg op.
1927 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1929 
1930  LogicalResult matchAndRewrite(tensor::CastOp castOp,
1931  PatternRewriter &rewriter) const override {
1932  if (!tensor::canFoldIntoProducerOp(castOp))
1933  return failure();
1934 
1935  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1936  if (!linalgOp)
1937  return failure();
1938 
1939  // Cast can be in conditionally reachable region, if which case folding will
1940  // generate invalid code. Only conservatively fold ops in same block for
1941  // now.
1942  if (castOp->getBlock() != linalgOp->getBlock())
1943  return failure();
1944 
1945  OpBuilder::InsertionGuard guard(rewriter);
1946  rewriter.setInsertionPoint(linalgOp);
1947 
1948  Location loc = linalgOp.getLoc();
1949  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
1950  unsigned resultNumber = resultValue.getResultNumber();
1951  auto resultType =
1952  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1953  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1954  // going from a more dynamic shape to a less dynamic shape. If the producer
1955  // for this cast, i.e. producer of the out operand, is also an operation
1956  // that folds with tensor.cast consumer (like this pattern), the cast will
1957  // continue to propagate as far up the stack as it can go.
1958  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
1959  Value newOperand =
1960  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1961  SmallVector<Value> newOperands{linalgOp.getDpsInputOperands()};
1962  SmallVector<Value> outputOperands{linalgOp.getDpsInitOperands()};
1963  outputOperands[resultNumber] = newOperand;
1964  newOperands.append(outputOperands.begin(), outputOperands.end());
1965 
1966  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1967  linalgOp->result_type_end());
1968  resultTypes[resultNumber] = resultType;
1969  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
1970 
1971  // Create a tensor.cast operation back to the original type.
1972  Value castBack = rewriter.create<tensor::CastOp>(
1973  loc, resultValue.getType(), newOp->getResult(resultNumber));
1974 
1975  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1976  results[resultNumber] = castBack;
1977  rewriter.replaceOp(linalgOp, results);
1978  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1979  return success();
1980  }
1981 };
1982 
1983 /// For each of the operand in `operands` this function maps the static sizes of
1984 /// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    // Scalars carry no shape information.
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size. `try_emplace` keeps the first size
    // seen for a given dim expression; later operands do not overwrite it.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
2019 
2020 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
2021 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2022 /// their result types is stored in `resultTypes`. If `opOperand` requires no
2023 /// change then `changeNeeded` is false and same operand is added in the
2024 /// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Push the original operand first; it is replaced in place below only if a
  // cast to a more static type is required.
  newOperands.push_back(src);
  // Scalars have no shape to refine.
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // Already fully static init operand: its type is a result type as-is.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    // Keep the existing size when it is already static or no inferred size is
    // available for this dim expression.
    if (!affineExprToSize.contains(dimExpr) ||
        !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  // Init operands contribute their (possibly refined) type to the result
  // types, preserving result order.
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2072 
2073 /// Static shapes for the operands can be inferred if any one of the operands
2074 /// have a static shape. This can be done by referring to the affine dim
2075 /// expressions for the operand.
2076 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2078 
2079  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2080  PatternRewriter &rewriter) const override {
2081  if (!linalgOp.hasTensorSemantics())
2082  return failure();
2083 
2084  // Maps must be projected permutations.
2085  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2086  return !map.isProjectedPermutation();
2087  }))
2088  return failure();
2089 
2090  // Maps affine dim expressions to the static size of that dimension.
2091  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2092  Location loc = linalgOp.getLoc();
2093 
2094  // For each of the affine dim expression, check if the size is known. If
2095  // known add that in the map.
2096  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2097 
2098  SmallVector<Value> newOperands;
2099  SmallVector<Type> resultTypes;
2100 
2101  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2102  // change in their types.
2103  bool changeNeeded = false;
2104  newOperands.reserve(linalgOp->getNumOperands());
2105  resultTypes.reserve(linalgOp.getNumDpsInits());
2106 
2107  // Iterate over all the operands and update the static sizes.
2108  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2109  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2110  affineExprToSize, linalgOp, newOperands,
2111  resultTypes, changeNeeded);
2112  }
2113 
2114  // If the generic op has all the required static information, no
2115  // canonicalization needed.
2116  if (!changeNeeded)
2117  return failure();
2118 
2119  // Clone op.
2120  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2121  SmallVector<Value> replacements;
2122  replacements.reserve(newOp->getNumResults());
2123  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2124  Value newResult = std::get<1>(it);
2125  Value oldResult = std::get<0>(it);
2126  Type newType = newResult.getType();
2127  Type oldType = oldResult.getType();
2128  replacements.push_back(
2129  (newType != oldType)
2130  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2131  : newResult);
2132  }
2133  rewriter.replaceOp(linalgOp, replacements);
2134  return success();
2135  }
2136 };
2137 
2138 } // namespace
2139 
2140 // All named ops canonicalizers and folders are auto-generated in the
2141 // .cpp.inc.
2142 
2143 //===----------------------------------------------------------------------===//
2144 // LinalgDialect
2145 //===----------------------------------------------------------------------===//
2146 
// Registers the dialect-wide canonicalization patterns defined above.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}
2152 
2154  Attribute value, Type type,
2155  Location loc) {
2156  return arith::ConstantOp::materialize(builder, value, type, loc);
2157 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1491
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:216
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:62
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:1842
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:204
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:275
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1322
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1370
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1127
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:100
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1000
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, const OpOperandVector &inputOperands, const OpOperandVector &outputOperands)
Definition: LinalgOps.cpp:876
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1757
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:242
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:652
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:235
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1054
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1156
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:268
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:130
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:44
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:259
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:262
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:341
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:212
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
OpListType & getOperations()
Definition: Block.h:126
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:166
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:170
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:363
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:256
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:260
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:312
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:152
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:189
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:417
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:406
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:433
This class represents an operand of an operation.
Definition: Value.h:261
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:217
This is a value defined by a result of an operation.
Definition: Value.h:448
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:460
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:469
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:632
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:700
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:629
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:514
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1334
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:262
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:1834
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:1874
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:1814
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:1825
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:88
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
MPInt ceil(const Fraction &f)
Definition: Fraction.h:70
MPInt floor(const Fraction &f)
Definition: Fraction.h:68
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:300
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:87
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:343
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:374
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:372
OpOperand vector that implicitly converts to a Value vector.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.