MLIR  16.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
30 #include "mlir/IR/Matchers.h"
32 #include "mlir/IR/PatternMatch.h"
34 
35 #include "llvm/ADT/DenseMap.h"
36 #include "llvm/ADT/SmallSet.h"
37 #include "llvm/ADT/StringSet.h"
38 #include "llvm/ADT/TypeSwitch.h"
39 #include "llvm/Support/FormatVariadic.h"
40 #include "llvm/Support/MathExtras.h"
41 #include "llvm/Support/raw_ostream.h"
42 
43 using namespace mlir;
44 using namespace mlir::linalg;
45 
46 //===----------------------------------------------------------------------===//
47 // Support for named Linalg ops defined in ods-gen.
48 //===----------------------------------------------------------------------===//
49 
52 
53 /// Fills the region of a structured operation using the provided
54 /// `regionBuilder`. The method is used by both named structured ops created by
55 /// ods-gen and by manually defined C++ ops. It is called by both builders and
56 /// parsers and creates a block with arguments corresponding to the elemental
57 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
58 /// ShapedType.
59 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
60  TypeRange inputTypes, TypeRange outputTypes,
62  RegionBuilderFn regionBuilder) {
63  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
64 
65  // TODO: atm all operands go through getElementTypeOrSelf,
66  // reconsider when we have evidence we need to.
67  SmallVector<Type, 8> argTypes;
69  for (auto containers : {inputTypes, outputTypes}) {
70  for (auto t : containers) {
71  argTypes.push_back(getElementTypeOrSelf(t));
72 
73  // TODO: Pass in a proper location here.
74  argLocs.push_back(opBuilder.getUnknownLoc());
75  }
76  }
77 
78  // RAII.
79  OpBuilder::InsertionGuard guard(opBuilder);
80  Block *body =
81  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
82 
83  opBuilder.setInsertionPointToStart(body);
84  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
85  regionBuilder(b, *body, attrs);
86 
87  // indexing_maps is an auto-generated method.
88 
89  // iterator_types is an auto-generated method.
90 }
91 
92 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
93 /// The result types are derived automatically if `resultTensorTypes` is none.
94 /// The body of the operation is filled using `regionBuilder`. All ods-gen
95 /// created structured operations use the method to implement their builders.
97  llvm::Optional<TypeRange> resultTensorTypes,
98  ValueRange inputs, ValueRange outputs,
99  ArrayRef<NamedAttribute> attributes,
100  RegionBuilderFn regionBuilder) {
101  // Derive the result types if needed.
102  SmallVector<Type> derivedResultTypes =
103  resultTensorTypes.value_or(TypeRange());
104  if (!resultTensorTypes)
105  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
106  [](Type type) { return type.isa<RankedTensorType>(); });
107 
108  state.addOperands(inputs);
109  state.addOperands(outputs);
110  state.addTypes(derivedResultTypes);
111  state.addAttributes(attributes);
112  state.addAttribute(
113  "operand_segment_sizes",
114  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
115  static_cast<int32_t>(outputs.size())}));
116 
117  // Create and fill the region of the structured operation.
118  Region &region = *state.addRegion();
119  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
120  state.attributes.getAttrs(), regionBuilder);
121 }
122 
123 /// Common parsing used for both named structured ops created by ods-gen and by
124 /// manually defined C++ ops. Does not handle regions.
125 static ParseResult
127  SmallVectorImpl<Type> &inputTypes,
128  SmallVectorImpl<Type> &outputTypes,
129  bool addOperandSegmentSizes = true) {
130  SMLoc inputsOperandsLoc, outputsOperandsLoc;
132  outputsOperands;
133 
134  if (parser.parseOptionalAttrDict(result.attributes))
135  return failure();
136 
137  if (succeeded(parser.parseOptionalKeyword("ins"))) {
138  if (parser.parseLParen())
139  return failure();
140 
141  inputsOperandsLoc = parser.getCurrentLocation();
142  if (parser.parseOperandList(inputsOperands) ||
143  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
144  return failure();
145  }
146 
147  if (succeeded(parser.parseOptionalKeyword("outs"))) {
148  outputsOperandsLoc = parser.getCurrentLocation();
149  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
150  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
151  return failure();
152  }
153 
154  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
155  result.operands) ||
156  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
157  result.operands))
158  return failure();
159 
160  if (addOperandSegmentSizes) {
161  result.addAttribute("operand_segment_sizes",
163  {static_cast<int32_t>(inputsOperands.size()),
164  static_cast<int32_t>(outputsOperands.size())}));
165  }
166  return success();
167 }
168 
170  ValueRange outputs) {
171  if (!inputs.empty())
172  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
173  if (!outputs.empty())
174  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
175 }
176 
178  ValueRange inputs,
179  ValueRange outputs) {
180  if (!inputs.empty()) {
181  p.printNewline();
182  p << "ins(" << inputs << " : " << inputs.getTypes() << ")";
183  }
184  if (!outputs.empty()) {
185  p.printNewline();
186  p << "outs(" << outputs << " : " << outputs.getTypes() << ")";
187  }
188 }
189 //===----------------------------------------------------------------------===//
190 // Specific parsing and printing for named structured ops created by ods-gen.
191 //===----------------------------------------------------------------------===//
192 
194  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
195  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
196  RegionBuilderFn regionBuilder) {
197  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
198  return parser.emitError(
199  parser.getCurrentLocation(),
200  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
201  "region expects {0} args, got {1}",
202  numRegionArgs, inputTypes.size() + outputTypes.size()));
203  }
204 
205  OpBuilder opBuilder(parser.getContext());
206  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
207  regionBuilder);
208  return success();
209 }
210 
211 static ParseResult
213  SmallVectorImpl<Type> &resultTypes) {
214  if (parser.parseOptionalArrowTypeList(resultTypes))
215  return failure();
216  return success();
217 }
218 
220  OperationState &result,
221  unsigned numRegionArgs,
222  RegionBuilderFn regionBuilder) {
223  // TODO: Enable when ods-gen supports captures.
224  SmallVector<Type, 1> inputTypes, outputTypes;
225  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
226  return failure();
227 
228  // TODO: consider merging results parsing into region parsing.
229  // Need to wait for declarative assembly resolution to decide.
230  SmallVector<Type, 1> outputTensorsTypes;
231  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
232  return failure();
233  result.addTypes(outputTensorsTypes);
234 
235  std::unique_ptr<Region> region = std::make_unique<Region>();
236  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
237  outputTypes, result.attributes.getAttrs(),
238  regionBuilder))
239  return failure();
240  result.addRegion(std::move(region));
241 
242  return success();
243 }
244 
246  TypeRange resultTypes) {
247  if (resultTypes.empty())
248  return;
249  p.printOptionalArrowTypeList(resultTypes);
250 }
251 
253  ValueRange inputs, ValueRange outputs) {
255  op->getAttrs(),
256  /*elidedAttrs=*/{"operand_segment_sizes",
257  // See generated code in mlir-linalg-yaml-gen.cpp
258  "linalg.memoized_indexing_maps"});
259 
260  // Printing is shared with generic ops, except for the region and
261  // attributes.
262  printCommonStructuredOpParts(p, inputs, outputs);
263 
264  // Results printing.
266 
267  // Region is elided.
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // Region builder helper.
272 // TODO: Move this to a utility library.
273 // The public methods on this class are referenced directly from generated code.
274 // Helper build the unary, binary, and type conversion functions defined by the
275 // DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
276 //
277 // Implementations of the math functions must be polymorphic over numeric types,
278 // internally performing necessary casts. If the function application makes no
279 // sense, then the only recourse is to assert and return nullptr. This can be
280 // extended later if it becomes possible to fail construction of the region. The
281 // invariant should be enforced at a higher level.
282 //
283 // TODO: These helpers are currently type polymorphic over the class of integer
284 // and floating point types, but they will not internally cast within bit
285 // widths of a class (mixed precision such as i8->i32) or across classes
286 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
287 // to be handled with care and work is being considered to extend the op
288 // language to make such cases explicit. In the mean-time, violating this will
289 // fail verification, which is deemed acceptable.
290 //===----------------------------------------------------------------------===//
291 
292 namespace {
293 
294 class RegionBuilderHelper {
295 public:
296  RegionBuilderHelper(MLIRContext *context, Block &block)
297  : context(context), block(block) {}
298 
299  // Build the unary functions defined by OpDSL.
300  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
301  if (!isFloatingPoint(arg))
302  llvm_unreachable("unsupported non numeric type");
303  OpBuilder builder = getBuilder();
304  switch (unaryFn) {
305  case UnaryFn::exp:
306  return builder.create<math::ExpOp>(arg.getLoc(), arg);
307  case UnaryFn::log:
308  return builder.create<math::LogOp>(arg.getLoc(), arg);
309  case UnaryFn::abs:
310  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
311  case UnaryFn::ceil:
312  return builder.create<math::CeilOp>(arg.getLoc(), arg);
313  case UnaryFn::floor:
314  return builder.create<math::FloorOp>(arg.getLoc(), arg);
315  case UnaryFn::negf:
316  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
317  }
318  llvm_unreachable("unsupported unary function");
319  }
320 
321  // Build the binary functions defined by OpDSL.
322  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
323  bool allComplex = isComplex(arg0) && isComplex(arg1);
324  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
325  bool allInteger = isInteger(arg0) && isInteger(arg1);
326  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
327  arg1.getType().getIntOrFloatBitWidth() == 1;
328  if (!allComplex && !allFloatingPoint && !allInteger)
329  llvm_unreachable("unsupported non numeric type");
330  OpBuilder builder = getBuilder();
331  switch (binaryFn) {
332  case BinaryFn::add:
333  if (allComplex)
334  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
335  if (allFloatingPoint)
336  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
337  if (allBool)
338  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
339  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
340  case BinaryFn::sub:
341  if (allComplex)
342  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
343  if (allFloatingPoint)
344  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
345  if (allBool)
346  llvm_unreachable("unsupported operation: sub with bools");
347  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
348  case BinaryFn::mul:
349  if (allComplex)
350  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
351  if (allFloatingPoint)
352  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
353  if (allBool)
354  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
355  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
356  case BinaryFn::max_signed:
357  assert(!allComplex);
358  if (allFloatingPoint)
359  return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
360  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
361  case BinaryFn::min_signed:
362  assert(!allComplex);
363  if (allFloatingPoint)
364  return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
365  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
366  case BinaryFn::max_unsigned:
367  assert(!allComplex);
368  if (allFloatingPoint)
369  return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
370  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
371  case BinaryFn::min_unsigned:
372  assert(!allComplex);
373  if (allFloatingPoint)
374  return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
375  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
376  }
377  llvm_unreachable("unsupported binary function");
378  }
379 
380  // Build the type functions defined by OpDSL.
381  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
382  switch (typeFn) {
383  case TypeFn::cast_signed:
384  return cast(toType, operand, false);
385  case TypeFn::cast_unsigned:
386  return cast(toType, operand, true);
387  }
388  llvm_unreachable("unsupported type conversion function");
389  }
390 
391  void yieldOutputs(ValueRange values) {
392  OpBuilder builder = getBuilder();
393  Location loc = builder.getUnknownLoc();
394  builder.create<YieldOp>(loc, values);
395  }
396 
397  Value constant(const std::string &value) {
398  OpBuilder builder = getBuilder();
399  Location loc = builder.getUnknownLoc();
400  Attribute valueAttr = parseAttribute(value, builder.getContext());
401  Type type = NoneType::get(builder.getContext());
402  if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
403  type = typedAttr.getType();
404  return builder.create<arith::ConstantOp>(loc, type, valueAttr);
405  }
406 
407  Value index(int64_t dim) {
408  OpBuilder builder = getBuilder();
409  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
410  }
411 
412  Type getIntegerType(unsigned width) {
413  return IntegerType::get(context, width);
414  }
415 
416  Type getFloat32Type() { return Float32Type::get(context); }
417  Type getFloat64Type() { return Float64Type::get(context); }
418 
419 private:
420  // Generates operations to cast the given operand to a specified type.
421  // If the cast cannot be performed, a warning will be issued and the
422  // operand returned as-is (which will presumably yield a verification
423  // issue downstream).
424  Value cast(Type toType, Value operand, bool isUnsignedCast) {
425  OpBuilder builder = getBuilder();
426  auto loc = operand.getLoc();
427  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
428  }
429 
430  bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
431  bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
432  bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
433 
434  OpBuilder getBuilder() {
435  OpBuilder builder(context);
436  builder.setInsertionPointToEnd(&block);
437  return builder;
438  }
439 
440  MLIRContext *context;
441  Block &block;
442 };
443 
444 } // namespace
445 
446 //===----------------------------------------------------------------------===//
447 // FillOp
448 //===----------------------------------------------------------------------===//
449 
450 namespace {
451 
452 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
453 ///
454 /// For such op chains, we can create new linalg.fill ops with the result
455 /// type of the tensor.expand/collapse_shape op.
456 template <typename TensorReshapeOp>
457 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
459  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
460  PatternRewriter &rewriter) const override {
461  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
462  if (!oldFill)
463  return failure();
464 
465  Location loc = oldFill.getLoc();
466  auto newInit = rewriter.create<TensorReshapeOp>(
467  loc, reshapeOp.getResultType(), oldFill.output(),
468  reshapeOp.getReassociation());
469  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
470  ValueRange{newInit});
471 
472  return success();
473  }
474 };
475 
476 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
477 /// filling value are the same.
478 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
480 
481  LogicalResult matchAndRewrite(tensor::PadOp padOp,
482  PatternRewriter &rewriter) const override {
483  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
484  if (!fillOp)
485  return failure();
486 
487  // We can only fold if the padding value is the same as the original
488  // filling value.
489  Value padValue = padOp.getConstantPaddingValue();
490  if (!padValue || fillOp.value() != padValue)
491  return failure();
492 
493  ReifiedRankedShapedTypeDims reifiedShape;
494  ReifyRankedShapedTypeOpInterface interface =
495  cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
496  if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
497  return rewriter.notifyMatchFailure(
498  padOp, "failed to reify tensor.pad op result shape");
499 
500  SmallVector<OpFoldResult> newShape =
501  getAsOpFoldResult(reifiedShape.front());
502  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
503  padOp.getLoc(), newShape, padOp.getResultType().getElementType());
504  Value replacement =
505  rewriter
506  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
507  ValueRange{emptyTensor})
508  .getResult(0);
509  if (replacement.getType() != padOp.getResultType()) {
510  replacement = rewriter.create<tensor::CastOp>(
511  fillOp.getLoc(), padOp.getResultType(), replacement);
512  }
513  rewriter.replaceOp(padOp, replacement);
514  return success();
515  }
516 };
517 
518 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
519 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
520 /// filling value are the same.
521 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
523 
524  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
525  PatternRewriter &rewriter) const override {
526  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
527  if (!srcPadOp)
528  return failure();
529 
530  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
531  return failure();
532 
533  // Walk back the tensor.insert_slice chain and find the first destination
534  // value at the start of the chain.
535  Value firstDest = insertOp.getDest();
536  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
537  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
538  return failure();
539 
540  // Make sure the range of values accessed are disjoint. Without this, we
541  // cannot fold tensor.pad away.
542  bool disjoint = false;
543  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
544  // If the dimension has dynamic offset/size, we cannot guarantee
545  // disjoint. So just skip it.
546  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
547  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
548  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
549  continue;
550 
551  // Get the range start and end, inclusively for both.
552  int64_t prevStart = prevOp.getStaticOffset(i);
553  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
554  prevOp.getStaticStride(i);
555  int64_t nextStart = insertOp.getStaticOffset(i);
556  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
557  insertOp.getStaticStride(i);
558  if (prevEnd < nextStart || nextEnd < prevStart) {
559  disjoint = true;
560  break;
561  }
562  }
563 
564  if (!disjoint)
565  break;
566  firstDest = prevOp.getDest();
567  }
568 
569  // Check whether the first destination is a fill op. For overlapped cases,
570  // this also cannot be true.
571  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
572  if (!dstFillOp)
573  return failure();
574 
575  // We can only fold if the padding value is the same as the original
576  // filling value.
577  Value padValue = srcPadOp.getConstantPaddingValue();
578  if (!padValue || dstFillOp.value() != padValue)
579  return failure();
580 
581  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
582  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
583 
584  Location loc = insertOp.getLoc();
585  MLIRContext *context = getContext();
586 
587  AffineExpr sym0, sym1;
588  bindSymbols(context, sym0, sym1);
589  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
590 
591  // Calculate the new offsets for the insert. It should be the old offsets
592  // plus low padding sizes.
593  SmallVector<OpFoldResult, 4> newOffsets;
594  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
595  newOffsets.push_back(makeComposedFoldedAffineApply(
596  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
597  }
598 
600  for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
601  newSizes.push_back(
602  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
603  .getResult());
604  }
605 
606  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
607  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
608  newSizes, insertOp.getMixedStrides());
609  return success();
610  }
611 };
612 
613 } // namespace
614 
615 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
616  MLIRContext *context) {
617  results
618  .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
619  FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
620  FoldInsertPadIntoFill>(context);
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // GenericOp
625 //===----------------------------------------------------------------------===//
626 
627 static void buildGenericRegion(
628  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
629  ValueRange outputs,
630  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
631  SmallVector<Type, 4> blockArgTypes;
632  SmallVector<Location, 4> blockArgLocs;
633  for (ValueRange container : {inputs, outputs}) {
634  for (Value v : container) {
635  blockArgTypes.push_back(getElementTypeOrSelf(v));
636  blockArgLocs.push_back(v.getLoc());
637  }
638  }
639 
640  OpBuilder::InsertionGuard guard(builder);
641  Block *bodyBlock =
642  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
643  bodyBuild(builder, loc, bodyBlock->getArguments());
644 }
645 
646 void GenericOp::getAsmBlockArgumentNames(Region &region,
647  OpAsmSetValueNameFn setNameFn) {
648  for (Value v : getRegionInputArgs())
649  setNameFn(v, "in");
650  for (Value v : getRegionOutputArgs())
651  setNameFn(v, "out");
652 }
653 
654 void GenericOp::build(
655  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
656  ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
657  ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
658  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
659  ArrayRef<NamedAttribute> attributes) {
660  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
661  iteratorTypes, doc, libraryCall);
662  result.addAttributes(attributes);
663  if (bodyBuild)
664  buildGenericRegion(builder, result.location, *result.regions.front(),
665  inputs, outputs, bodyBuild);
666 }
667 
668 void GenericOp::build(
669  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
670  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
671  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
672  StringRef libraryCall,
673  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
674  ArrayRef<NamedAttribute> attributes) {
675  build(builder, result, resultTensorTypes, inputs, outputs,
676  builder.getAffineMapArrayAttr(indexingMaps),
677  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
678  iteratorTypes,
679  [&](utils::IteratorType iter) -> mlir::Attribute {
680  return IteratorTypeAttr::get(builder.getContext(), iter);
681  }))),
682  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
683  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
684  bodyBuild, attributes);
685 }
686 
687 void GenericOp::build(
688  OpBuilder &builder, OperationState &result, ValueRange inputs,
689  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
690  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
691  StringRef libraryCall,
692  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
693  ArrayRef<NamedAttribute> attributes) {
694  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
695  iteratorTypes, doc, libraryCall, bodyBuild, attributes);
696 }
697 
698 void GenericOp::build(
699  OpBuilder &builder, OperationState &result, ValueRange inputs,
700  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
701  ArrayRef<utils::IteratorType> iteratorTypes,
702  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
703  ArrayRef<NamedAttribute> attributes) {
704  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
705  /*doc=*/"",
706  /*libraryCall=*/"", bodyBuild, attributes);
707 }
708 
709 void GenericOp::build(
710  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
711  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
712  ArrayRef<utils::IteratorType> iteratorTypes,
713  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
714  ArrayRef<NamedAttribute> attributes) {
715  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
716  iteratorTypes,
717  /*doc=*/"",
718  /*libraryCall=*/"", bodyBuild, attributes);
719 }
720 
722  p << " ";
723 
724  // Print extra attributes.
725  auto genericAttrNames = linalgTraitAttrNames();
726 
727  llvm::StringSet<> genericAttrNamesSet;
728  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
729  SmallVector<NamedAttribute, 8> genericAttrs;
730  for (auto attr : (*this)->getAttrs()) {
731  if (attr.getName() == getIteratorTypesAttrName()) {
732  auto iteratorTypes =
733  attr.getValue()
734  .cast<ArrayAttr>()
735  .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
736  // Convert IteratorType enums into the string representation. This is
737  // needed, because tests still use the old format when 'iterator_types'
738  // attribute is represented as an array of strings.
739  // TODO: Remove this conversion once tests are fixed.
740  SmallVector<Attribute> iteratorTypeNames =
741  llvm::to_vector(llvm::map_range(
742  iteratorTypes, [&](utils::IteratorType t) -> Attribute {
743  return StringAttr::get(getContext(), stringifyIteratorType(t));
744  }));
745 
746  genericAttrs.emplace_back(
747  getIteratorTypesAttrName(),
748  ArrayAttr::get(getContext(), iteratorTypeNames));
749  } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
750  genericAttrs.push_back(attr);
751  }
752  }
753  if (!genericAttrs.empty()) {
754  auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
755  p << genericDictAttr;
756  }
757 
758  // Printing is shared with named ops, except for the region and attributes
759  printCommonStructuredOpParts(p, SmallVector<Value>(getDpsInputOperands()),
760  SmallVector<Value>(getDpsInitOperands()));
761 
762  genericAttrNames.push_back("operand_segment_sizes");
763  genericAttrNamesSet.insert(genericAttrNames.back());
764 
765  bool hasExtraAttrs = false;
766  for (NamedAttribute n : (*this)->getAttrs()) {
767  if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
768  break;
769  }
770  if (hasExtraAttrs) {
771  p << " attrs = ";
772  p.printOptionalAttrDict((*this)->getAttrs(),
773  /*elidedAttrs=*/genericAttrNames);
774  }
775 
776  // Print region.
777  if (!getRegion().empty()) {
778  p << ' ';
779  p.printRegion(getRegion());
780  }
781 
782  // Print results.
783  printNamedStructuredOpResults(p, getResultTensors().getTypes());
784 }
785 
786 ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
787  DictionaryAttr dictAttr;
788  // Parse the core linalg traits that must check into a dictAttr.
789  // The name is unimportant as we will overwrite result.attributes.
790  // The core linalg traits must contain the information necessary to pass the
791  // verifier.
792  if (parser.parseAttribute(dictAttr, "_", result.attributes))
793  return failure();
794  result.attributes.assign(dictAttr.getValue().begin(),
795  dictAttr.getValue().end());
796 
797  // Convert array of string into an array of IteratyType enums. This is needed,
798  // because tests still use the old format when 'iterator_types' attribute is
799  // represented as an array of strings.
800  // TODO: Remove this conversion once tests are fixed.
801  ArrayAttr iteratorTypes =
802  result.attributes.get(getIteratorTypesAttrName(result.name))
803  .cast<ArrayAttr>();
804 
805  SmallVector<Attribute> iteratorTypeAttrs;
806 
807  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
808  auto maybeIteratorType = utils::symbolizeIteratorType(s);
809  if (!maybeIteratorType.has_value())
810  return parser.emitError(parser.getCurrentLocation())
811  << "unexpected iterator_type (" << s << ")";
812 
813  iteratorTypeAttrs.push_back(
814  IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
815  }
816  result.attributes.set(getIteratorTypesAttrName(result.name),
817  parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
818 
819  // Parsing is shared with named ops, except for the region.
820  SmallVector<Type, 1> inputTypes, outputTypes;
821  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
822  return failure();
823 
824  // Optional attributes may be added.
825  if (succeeded(parser.parseOptionalKeyword("attrs")))
826  if (failed(parser.parseEqual()) ||
827  failed(parser.parseOptionalAttrDict(result.attributes)))
828  return failure();
829 
830  std::unique_ptr<Region> region = std::make_unique<Region>();
831  if (parser.parseRegion(*region, {}))
832  return failure();
833  result.addRegion(std::move(region));
834 
835  // Generic ops may specify that a subset of its outputs are tensors. Such
836  // outputs are specified in the result type.
837  // TODO: may need to move output parsing before region parsing.
838  // Need to wait for declarative assembly resolution to decide.
839  SmallVector<Type, 1> outputTensorsTypes;
840  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
841  return failure();
842  result.addTypes(outputTensorsTypes);
843 
844  return success();
845 }
846 
849  &effects,
850  ValueRange results, const OpOperandVector &inputOperands,
851  const OpOperandVector &outputOperands) {
852  for (auto *operand : inputOperands) {
853  if (!operand->get().getType().isa<MemRefType>())
854  continue;
855  effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
857  }
858  for (auto *operand : outputOperands) {
859  if (!operand->get().getType().isa<MemRefType>())
860  continue;
861  effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
863  effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
865  }
866 }
867 
868 void GenericOp::getEffects(
870  &effects) {
871  getGenericEffectsImpl(effects, getOperation()->getResults(),
872  getDpsInputOperands(), getDpsInitOperands());
873 }
874 
876 
877 namespace {
878 
879 /// Remove generic operations (on tensors) that are just copying
880 /// the values from inputs to the results. Requirements are
881 /// 1) All iterator types are parallel
882 /// 2) The body contains just a yield operation with the yielded values being
883 /// the arguments corresponding to the operands.
884 struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
886 
887  LogicalResult matchAndRewrite(GenericOp genericOp,
888  PatternRewriter &rewriter) const override {
889  // Check all indexing maps are identity.
890  if (llvm::any_of(genericOp.getIndexingMapsArray(),
891  [](AffineMap map) { return !map.isIdentity(); }))
892  return failure();
893 
894  // Check that the body of the linalg operation is just a linalg.yield
895  // operation.
896  Block &body = genericOp.getRegion().front();
897  if (!llvm::hasSingleElement(body))
898  return failure();
899  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
900  if (!yieldOp)
901  return failure();
902 
903  // In the buffer case, we need to check exact buffer equality.
904  if (genericOp.hasBufferSemantics()) {
905  if (genericOp.getNumDpsInputs() == 1 && genericOp.getNumDpsInits() == 1 &&
906  genericOp.getDpsInputOperand(0)->get() ==
907  genericOp.getDpsInitOperand(0)->get()) {
908  rewriter.eraseOp(genericOp);
909  return success();
910  }
911  return failure();
912  }
913 
914  // Mixed semantics is not supported yet.
915  if (!genericOp.hasTensorSemantics())
916  return failure();
917 
918  // Get the argument number of the returned values. That is the operand
919  // number to use for replacing uses of this operation.
920  SmallVector<Value> returnedArgs;
921  for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
922  auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
923  if (!yieldArg || yieldArg.getOwner() != &body)
924  return failure();
925  unsigned argumentNumber = yieldArg.getArgNumber();
926  Value returnedArg = genericOp->getOperand(argumentNumber);
927  Type resultType = genericOp->getResult(yieldVal.index()).getType();
928  // The input can have a different type than the result, e.g. a dynamic
929  // input dimension can be turned into a static output dimension.
930  Type returnType = returnedArg.getType();
931  if (returnType != resultType) {
932  // Distinguish between sparse conversion or dense tensor casting.
933  // TODO: unify the two ops?
934  if (sparse_tensor::getSparseTensorEncoding(returnType) ||
936  returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
937  genericOp.getLoc(), resultType, returnedArg);
938  else {
939  if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
940  resultType))
941  return failure();
942  returnedArg = rewriter.create<tensor::CastOp>(
943  genericOp.getLoc(), resultType, returnedArg);
944  }
945  }
946  returnedArgs.push_back(returnedArg);
947  }
948 
949  if (returnedArgs.size() != genericOp->getNumResults())
950  return failure();
951  rewriter.replaceOp(genericOp, returnedArgs);
952  return success();
953  }
954 };
955 
956 } // namespace
957 
// Registers canonicalization patterns for linalg.generic: currently only the
// erasure of identity (pure copy) generics.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityGenericOp>(context);
}
962 
963 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
965  return memref::foldMemRefCast(*this);
966 }
967 
968 //===----------------------------------------------------------------------===//
969 // MapOp
970 //===----------------------------------------------------------------------===//
971 
973  OpAsmParser &parser, OperationState &result,
974  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
975  nullptr) {
976  // Parse `ins` and `outs`.
977  SmallVector<Type, 4> inputTypes, outputTypes;
978  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
979  /*addOperandSegmentSizes=*/false))
980  return failure();
981 
982  // Add result types.
983  for (Type outputType : outputTypes) {
984  if (outputType.isa<RankedTensorType>())
985  result.addTypes(outputType);
986  }
987 
988  // Parse required attributes.
989  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
990  return failure();
991 
992  // Parse optional attributes.
993  if (parser.parseOptionalAttrDict(result.attributes))
994  return failure();
995  return success();
996 }
997 
998 void MapOp::getAsmBlockArgumentNames(Region &region,
999  OpAsmSetValueNameFn setNameFn) {
1000  for (Value v : getRegionInputArgs())
1001  setNameFn(v, "in");
1002 }
1003 
1004 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1005  if (!getResults().empty())
1006  setNameFn(getResults().front(), "mapped");
1007 }
1008 
1009 void MapOp::build(
1010  OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1011  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1012  ArrayRef<NamedAttribute> attributes) {
1013  build(builder, result, TypeRange{}, inputs, init);
1014  result.addAttributes(attributes);
1015 
1016  // Add output types for `RankedTensorType` output arguments.
1017  Type initType = init.getType();
1018  if (initType.isa<RankedTensorType>())
1019  result.addTypes(initType);
1020 
1021  if (bodyBuild)
1022  buildGenericRegion(builder, result.location, *result.regions.front(),
1023  inputs, /*outputs=*/{}, bodyBuild);
1024 }
1025 
1026 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1027  if (parseDstStyleOp(parser, result))
1028  return failure();
1029 
1031  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1032  /*allowType=*/true, /*allowAttrs=*/true)) {
1033  return failure();
1034  }
1035 
1036  Region *body = result.addRegion();
1037  if (parser.parseRegion(*body, regionArgs))
1038  return failure();
1039 
1040  return success();
1041 }
1042 
1043 void MapOp::print(OpAsmPrinter &p) {
1044  p.increaseIndent();
1046  p, SmallVector<Value>(getDpsInputOperands()),
1047  SmallVector<Value>(getDpsInitOperands()));
1048  p.printOptionalAttrDict((*this)->getAttrs());
1049 
1050  p.printNewline();
1051  p << "(";
1052  llvm::interleaveComma(getMapper().getArguments(), p,
1053  [&](auto arg) { p.printRegionArgument(arg); });
1054  p << ") ";
1055 
1056  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1057  p.decreaseIndent();
1058 }
1059 
1061  auto *bodyBlock = getBody();
1062  auto blockArgs = bodyBlock->getArguments();
1063 
1064  // Checks if the number of `inputs` match the arity of the `mapper` region.
1065  if (getInputs().size() != blockArgs.size())
1066  return emitOpError() << "expects number of operands to match the arity of "
1067  "mapper, but got: "
1068  << getInputs().size() << " and " << blockArgs.size();
1069 
1070  // The parameters of mapper should all match the element type // of inputs.
1071  for (const auto &[bbArgType, inputArg] :
1072  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1073  auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
1074  if (bbArgType != inputElemType) {
1075  return emitOpError() << "expected element type of input " << inputElemType
1076  << " to match bbArg type " << bbArgType;
1077  }
1078  }
1079 
1080  // The shape of each input must match the shape of the output.
1081  auto outputShape = getInit().getType().cast<ShapedType>().getShape();
1082  for (Type inputArgType : TypeRange{getInputs()}) {
1083  auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
1084  if (inputElemShape != outputShape) {
1085  return emitOpError() << "expected shape of input (" << inputElemShape
1086  << ") to match shape of output (" << outputShape
1087  << ")";
1088  }
1089  }
1090 
1091  return success();
1092 }
1093 
1094 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1095  int64_t rank = getInit().getType().getRank();
1096  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1097 }
1098 
1099 ArrayAttr MapOp::getIndexingMaps() {
1100  Builder builder(getContext());
1101  int64_t rank = getInit().getType().getRank();
1102  int64_t numIndexingMaps = getOperands().size();
1104  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1105 }
1106 
1107 void MapOp::getEffects(
1109  &effects) {
1110  getGenericEffectsImpl(effects, getOperation()->getResults(),
1111  getDpsInputOperands(), getDpsInitOperands());
1112 }
1113 
1114 //===----------------------------------------------------------------------===//
1115 // ReduceOp
1116 //===----------------------------------------------------------------------===//
1117 
1118 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1119  OpAsmSetValueNameFn setNameFn) {
1120  for (Value v : getRegionInputArgs())
1121  setNameFn(v, "in");
1122  for (Value v : getRegionOutputArgs())
1123  setNameFn(v, "init");
1124 }
1125 
1126 void ReduceOp::getAsmResultNames(
1127  function_ref<void(Value, StringRef)> setNameFn) {
1128  if (!getResults().empty())
1129  setNameFn(getResults().front(), "reduced");
1130 }
1131 
1132 void ReduceOp::build(
1133  OpBuilder &builder, OperationState &result, ValueRange inputs,
1134  ValueRange inits, ArrayRef<int64_t> dimensions,
1135  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1136  ArrayRef<NamedAttribute> attributes) {
1137  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1138  result.addAttributes(attributes);
1139 
1140  // Add output types for `RankedTensorType` output arguments.
1141  for (Value init : inits) {
1142  Type initType = init.getType();
1143  if (initType.isa<RankedTensorType>())
1144  result.addTypes(initType);
1145  }
1146 
1147  if (bodyBuild)
1148  buildGenericRegion(builder, result.location, *result.regions.front(),
1149  inputs, inits, bodyBuild);
1150 }
1151 
1152 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1153  int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
1154  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1155  utils::IteratorType::parallel);
1156  for (int64_t reductionDim : getDimensions())
1157  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1158  return iteratorTypes;
1159 }
1160 
1161 ArrayAttr ReduceOp::getIndexingMaps() {
1162  int64_t inputRank = getInputs()[0].getType().cast<ShapedType>().getRank();
1163  SmallVector<AffineMap> affineMaps(
1164  getNumDpsInputs(),
1165  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1166  AffineMap resultMap =
1167  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1168  .dropResults(getDimensions());
1169  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1170  affineMaps.push_back(resultMap);
1171  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1172 }
1173 
1174 void ReduceOp::getEffects(
1176  &effects) {
1177  getGenericEffectsImpl(effects, getOperation()->getResults(),
1178  getDpsInputOperands(), getDpsInitOperands());
1179 }
1180 
1182  NamedAttrList &attributes,
1183  StringRef attributeName) {
1184  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1185  return failure();
1186 
1187  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1188  return success();
1189 }
1190 
1191 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1192  if (parseDstStyleOp(
1193  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1194  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1195  }))
1196  return failure();
1197 
1199  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1200  /*allowType=*/true, /*allowAttrs=*/true)) {
1201  return failure();
1202  }
1203 
1204  Region *body = result.addRegion();
1205  if (parser.parseRegion(*body, regionArgs))
1206  return failure();
1207 
1208  return success();
1209 }
1210 
1211 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1212  ArrayRef<int64_t> attributeValue) {
1213  p << attributeName << " = [" << attributeValue << "] ";
1214 }
1215 
1216 void ReduceOp::print(OpAsmPrinter &p) {
1217  p.increaseIndent();
1219  p, SmallVector<Value>(getDpsInputOperands()),
1220  SmallVector<Value>(getDpsInitOperands()));
1221  p.printNewline();
1222 
1223  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1224  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1225 
1226  p.printNewline();
1227  p << "(";
1228  llvm::interleaveComma(getCombiner().getArguments(), p,
1229  [&](auto arg) { p.printRegionArgument(arg); });
1230  p << ") ";
1231 
1232  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1233  p.decreaseIndent();
1234 }
1235 
1237  ArrayRef<int64_t> dimensionsRef = getDimensions();
1238 
1239  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1240  if (getInputs()[i].getType().cast<ShapedType>().getShape() !=
1241  getInputs()[0].getType().cast<ShapedType>().getShape()) {
1242  return emitOpError() << "expects all inputs to have the same shapes. "
1243  "Shape at input-index "
1244  << i
1245  << " is not equal to the shape at input-index 0.";
1246  }
1247  }
1248  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1249  if (getInits()[i].getType().cast<ShapedType>().getShape() !=
1250  getInits()[0].getType().cast<ShapedType>().getShape()) {
1251  return emitOpError() << "expects all outputs to have the same shapes. "
1252  "Shape at output-index "
1253  << i
1254  << " is not equal to the shape at output-index 0.";
1255  }
1256  }
1257  auto inputType = getInputs()[0].getType().cast<ShapedType>();
1258  auto initType = getInits()[0].getType().cast<ShapedType>();
1259 
1260  DenseSet<int64_t> dimensionsToReduce;
1261  for (int64_t dimension : dimensionsRef) {
1262  if (dimension < 0 || dimension >= inputType.getRank()) {
1263  return emitOpError()
1264  << "dimensions for reduction should be in the range [0, "
1265  << inputType.getRank() - 1 << "].";
1266  }
1267  dimensionsToReduce.insert(dimension);
1268  }
1269 
1270  auto inputDims = inputType.getShape();
1271  auto initDims = initType.getShape();
1272 
1273  // Input dimensions that will be left after the reduction.
1274  SmallVector<int64_t> reducedInputDims;
1275  for (const auto &en : llvm::enumerate(inputDims)) {
1276  if (!dimensionsToReduce.count(en.index()))
1277  reducedInputDims.push_back(en.value());
1278  }
1279 
1280  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1281  return emitOpError() << "number of dimensions after reduction "
1282  << reducedInputDims.size()
1283  << " doesn't match the init rank "
1284  << initType.getRank();
1285  }
1286 
1287  if (reducedInputDims != initDims)
1288  return emitOpError() << "init dimensions [" << initDims
1289  << "] doesn't match input dimensions after reduction ["
1290  << reducedInputDims << "]";
1291 
1292  Block *block = getBody();
1293  if (block->getNumArguments() != this->getNumOperands())
1294  return emitOpError()
1295  << "mismatching number of operands and block arguments";
1296 
1297  // Check that the first block arguments match the element type of the inputs.
1298  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1299  Type inputElementType = input.getType().cast<ShapedType>().getElementType();
1300  if (inputElementType != bbArg.getType())
1301  return emitOpError()
1302  << "input element type " << inputElementType
1303  << " does not match corresponding block argument type "
1304  << bbArg.getType();
1305  }
1306 
1307  // Check that the last block arguments match the element type of the outputs.
1308  for (auto [output, bbArg] :
1309  llvm::zip(getDpsInitOperands(),
1310  block->getArguments().take_back(getNumDpsInits()))) {
1311  auto outputElementType =
1312  output->get().getType().cast<ShapedType>().getElementType();
1313  if (outputElementType != bbArg.getType())
1314  return emitOpError()
1315  << "output element type " << outputElementType
1316  << " does not match corresponding block argument type "
1317  << bbArg.getType();
1318  }
1319  return success();
1320 }
1321 
1322 //===----------------------------------------------------------------------===//
1323 // TransposeOp
1324 //===----------------------------------------------------------------------===//
1325 
1326 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1327  Region &region, ValueRange inputs,
1328  ValueRange outputs) {
1329  buildGenericRegion(builder, loc, region, inputs, outputs,
1330  [](OpBuilder &b, Location loc, ValueRange args) {
1331  b.create<linalg::YieldOp>(loc, args[0]);
1332  });
1333 }
1334 
1335 void TransposeOp::build(::mlir::OpBuilder &builder,
1336  ::mlir::OperationState &result, Value input, Value init,
1337  DenseI64ArrayAttr permutation,
1338  ArrayRef<NamedAttribute> attributes) {
1339  result.addOperands(input);
1340  result.addOperands(init);
1341  result.addAttribute(getPermutationAttrName(result.name), permutation);
1342  result.addAttributes(attributes);
1343 
1344  // Add output types for `RankedTensorType` output arguments.
1345  Type initType = init.getType();
1346  if (initType.isa<RankedTensorType>())
1347  result.addTypes(initType);
1348 
1349  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1350  init);
1351 }
1352 
1353 void TransposeOp::build(::mlir::OpBuilder &builder,
1354  ::mlir::OperationState &result, Value input, Value init,
1355  ArrayRef<int64_t> permutation,
1356  ArrayRef<NamedAttribute> attributes) {
1357  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1358  attributes);
1359 }
1360 
1361 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1362  if (failed(parseDstStyleOp(
1363  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1364  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1365  })))
1366  return failure();
1367 
1368  OpBuilder builder(parser.getContext());
1369  buildIdentityRegion(builder, result.location, *result.addRegion(),
1370  /*inputs=*/result.operands,
1371  /*outputs=*/{});
1372  return success();
1373 }
1374 
1375 void TransposeOp::getAsmResultNames(
1376  function_ref<void(Value, StringRef)> setNameFn) {
1377  if (!getResults().empty())
1378  setNameFn(getResults().front(), "transposed");
1379 }
1380 
1382  p.increaseIndent();
1384  p, SmallVector<Value>(getDpsInputOperands()),
1385  SmallVector<Value>(getDpsInitOperands()));
1386  p.printNewline();
1387 
1388  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1389  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1390  p.decreaseIndent();
1391 }
1392 
1394  ArrayRef<int64_t> permutationRef = getPermutation();
1395 
1396  if (!isPermutationVector(permutationRef))
1397  return emitOpError("permutation is not valid");
1398 
1399  auto inputType = getInput().getType();
1400  auto initType = getInit().getType();
1401 
1402  int64_t rank = inputType.getRank();
1403 
1404  if (rank != initType.getRank())
1405  return emitOpError() << "input rank " << rank
1406  << " does not match init rank " << initType.getRank();
1407 
1408  if (rank != static_cast<int64_t>(permutationRef.size()))
1409  return emitOpError() << "size of permutation " << permutationRef.size()
1410  << " does not match the argument rank " << rank;
1411 
1412  auto inputDims = inputType.getShape();
1413  auto initDims = initType.getShape();
1414 
1415  for (int64_t i = 0; i < rank; ++i) {
1416  int64_t inputDim = inputDims[permutationRef[i]];
1417  int64_t initDim = initDims[i];
1418 
1419  if (inputDim != initDim) {
1420  return emitOpError() << "dim(result, " << i << ") = " << initDim
1421  << " doesn't match dim(input, permutation[" << i
1422  << "]) = " << inputDim;
1423  }
1424  }
1425 
1426  return success();
1427 }
1428 
1429 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1430  int64_t rank = getInit().getType().getRank();
1431  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1432 }
1433 
1434 ArrayAttr TransposeOp::getIndexingMaps() {
1435  Builder builder(getContext());
1436  int64_t rank = getInit().getType().getRank();
1437  return builder.getAffineMapArrayAttr(
1438  {builder.getMultiDimIdentityMap(rank),
1440  llvm::to_vector_of<unsigned>(getPermutation()), getContext())});
1441 }
1442 
1443 void TransposeOp::getEffects(
1445  &effects) {
1446  getGenericEffectsImpl(effects, getOperation()->getResults(),
1447  getDpsInputOperands(), getDpsInitOperands());
1448 }
1449 
1450 //===----------------------------------------------------------------------===//
1451 // BroadcastOp
1452 //===----------------------------------------------------------------------===//
1453 
1454 void BroadcastOp::build(::mlir::OpBuilder &builder,
1455  ::mlir::OperationState &result, Value input, Value init,
1456  DenseI64ArrayAttr dimensions,
1457  ArrayRef<NamedAttribute> attributes) {
1458  result.addOperands(input);
1459  result.addOperands(init);
1460  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
1461  result.addAttributes(attributes);
1462 
1463  // Add output types for `RankedTensorType` output arguments.
1464  Type initType = init.getType();
1465  if (initType.isa<RankedTensorType>())
1466  result.addTypes(initType);
1467 
1468  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1469  init);
1470 }
1471 
1472 void BroadcastOp::build(::mlir::OpBuilder &builder,
1473  ::mlir::OperationState &result, Value input, Value init,
1474  ArrayRef<int64_t> dimensions,
1475  ArrayRef<NamedAttribute> attributes) {
1476  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
1477  attributes);
1478 }
1479 
1480 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
1481  if (failed(parseDstStyleOp(
1482  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1483  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1484  })))
1485  return failure();
1486 
1487  OpBuilder builder(parser.getContext());
1488  buildIdentityRegion(builder, result.location, *result.addRegion(),
1489  /*inputs=*/result.operands,
1490  /*outputs=*/{});
1491  return success();
1492 }
1493 
1494 void BroadcastOp::getAsmResultNames(
1495  function_ref<void(Value, StringRef)> setNameFn) {
1496  if (!getResults().empty())
1497  setNameFn(getResults().front(), "broadcasted");
1498 }
1499 
1501  p.increaseIndent();
1503  p, SmallVector<Value>(getDpsInputOperands()),
1504  SmallVector<Value>(getDpsInitOperands()));
1505  p.printNewline();
1506 
1507  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1508  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1509  p.decreaseIndent();
1510 }
1511 
1513  ArrayRef<int64_t> dimensionsRef = getDimensions();
1514 
1515  auto inputType = getInput().getType();
1516  auto initType = getInit().getType();
1517 
1518  int64_t inputRank = inputType.getRank();
1519  int64_t initRank = initType.getRank();
1520 
1521  auto inputShape = inputType.getShape();
1522  auto initShape = initType.getShape();
1523 
1524  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1525  return emitOpError() << "input rank plus added dimensions does not "
1526  "match init rank. input rank: "
1527  << inputRank
1528  << ", dimensions size: " << dimensionsRef.size()
1529  << ", init rank: " << initRank;
1530 
1531  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
1532  if (dim < 0 || dim >= initRank)
1533  return emitOpError() << "dimension " << idx
1534  << " is out of range. expected range: [0, "
1535  << initRank - 1 << "], got: " << dim;
1536  }
1537 
1538  // Mapping from input dims to init dims.
1539  SmallVector<int64_t> dimMap;
1540  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1541  if (!llvm::is_contained(dimensionsRef, dim))
1542  dimMap.push_back(dim);
1543  }
1544 
1545  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1546  // This dimensions is mapped from the input. Init and input dims should
1547  // match.
1548  if (inputShape[inputDimIdx] != initShape[initDimIdx])
1549  return emitOpError() << "input dim " << inputDimIdx
1550  << " should match init dim " << initDimIdx
1551  << ". input: " << inputShape[inputDimIdx]
1552  << ", init: " << initShape[initDimIdx];
1553  }
1554 
1555  return success();
1556 }
1557 
1558 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
1559  int64_t rank = getInit().getType().getRank();
1560  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1561 }
1562 
1563 ArrayAttr BroadcastOp::getIndexingMaps() {
1564  Builder builder(getContext());
1565  int64_t rank = getInit().getType().getRank();
1566  return builder.getAffineMapArrayAttr(
1567  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
1568  builder.getMultiDimIdentityMap(rank)});
1569 }
1570 
1571 void BroadcastOp::getEffects(
1573  &effects) {
1574  getGenericEffectsImpl(effects, getOperation()->getResults(),
1575  getDpsInputOperands(), getDpsInitOperands());
1576 }
1577 
1578 //===----------------------------------------------------------------------===//
1579 // YieldOp
1580 //===----------------------------------------------------------------------===//
1581 
1583  if (getNumOperands() > 0)
1584  p << ' ' << getOperands();
1585  p.printOptionalAttrDict((*this)->getAttrs());
1586  if (getNumOperands() > 0)
1587  p << " : " << getOperandTypes();
1588 }
1589 
1590 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
1592  SmallVector<Type, 2> types;
1593  SMLoc loc = parser.getCurrentLocation();
1594  return failure(parser.parseOperandList(opInfo) ||
1595  parser.parseOptionalAttrDict(result.attributes) ||
1596  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
1597  parser.resolveOperands(opInfo, types, loc, result.operands));
1598 }
1599 
1600 // Check the operand number and types must match the element types of the
1601 // LinalgOp interface's shaped operands.
1602 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
1603  if (op.getNumOperands() != linalgOp.getNumDpsInits())
1604  return op.emitOpError("expected number of yield values (")
1605  << linalgOp.getNumDpsInits()
1606  << ") to match the number of operands of the enclosing "
1607  << "LinalgOp (" << op.getNumOperands() << ")";
1608 
1609  for (OpOperand &opOperand : op->getOpOperands()) {
1610  OpOperand *outputOperand =
1611  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
1612  Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
1613  if (opOperand.get().getType() != elementType)
1614  return op.emitOpError("type of yield operand ")
1615  << (opOperand.getOperandNumber() + 1) << " ("
1616  << opOperand.get().getType() << ") doesn't match "
1617  << "the element type of the enclosing linalg.generic op ("
1618  << elementType << ")";
1619  }
1620  return success();
1621 }
1622 
1624  auto *parentOp = (*this)->getParentOp();
1625  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
1626  return emitOpError("expected single non-empty parent region");
1627 
1628  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
1629  return verifyYield(*this, linalgOp);
1630 
1631  return emitOpError("expected parent op with LinalgOp interface");
1632 }
1633 
1634 //===----------------------------------------------------------------------===//
1635 // IndexOp
1636 //===----------------------------------------------------------------------===//
1637 
1639  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
1640  if (!linalgOp)
1641  return emitOpError("expected parent op with LinalgOp interface");
1642  if (linalgOp.getNumLoops() <= getDim())
1643  return emitOpError("expected dim (")
1644  << getDim() << ") to be lower than the number of loops ("
1645  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
1646  return success();
1647 }
1648 
1649 /////// Operations corresponding to library calls defined with Tablegen ////////
1650 
1651 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
1652 
1653 #define GET_OP_CLASSES
1654 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
1655 
1656 #define GET_OP_CLASSES
1657 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
1658 
1660  unsigned rank,
1661  MLIRContext *context) {
1662  if (maybeMap)
1663  return *maybeMap;
1664  if (rank == 0)
1665  return AffineMap::get(context);
1666  return AffineMap::getMultiDimIdentityMap(rank, context);
1667 }
1668 
1670 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
1671  MLIRContext *context) {
1673  res.reserve(num);
1674  for (unsigned i = 0; i < num; ++i)
1675  res.push_back(getAffineDimExpr(startIdx++, context));
1676  return res;
1677 }
1678 
1681  auto rangeA = llvm::make_range(a.begin(), a.end());
1682  auto rangeB = llvm::make_range(b.begin(), b.end());
1683  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
1684  return llvm::to_vector<4>(concatRanges);
1685 }
1686 
1687 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
1688  if (auto memref = t.dyn_cast<MemRefType>()) {
1689  ss << "view";
1690  for (auto size : memref.getShape())
1691  if (size < 0)
1692  ss << "sx";
1693  else
1694  ss << size << "x";
1695  appendMangledType(ss, memref.getElementType());
1696  } else if (auto vec = t.dyn_cast<VectorType>()) {
1697  ss << "vector";
1698  llvm::interleave(
1699  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
1700  appendMangledType(ss, vec.getElementType());
1701  } else if (t.isSignlessIntOrIndexOrFloat()) {
1702  ss << t;
1703  } else {
1704  llvm_unreachable("Invalid type for linalg library name mangling");
1705  }
1706 }
1707 
1709  assert(isa<LinalgOp>(op));
1710  std::string name(op->getName().getStringRef().str());
1711  name.reserve(128);
1712  std::replace(name.begin(), name.end(), '.', '_');
1713  llvm::raw_string_ostream ss(name);
1714  ss << "_";
1715  auto types = op->getOperandTypes();
1716  llvm::interleave(
1717  types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
1718  [&]() { ss << "_"; });
1719  return ss.str();
1720 }
1721 
1722 //===----------------------------------------------------------------------===//
1723 // Canonicalizers and Folders.
1724 //===----------------------------------------------------------------------===//
1725 
1726 namespace {
1727 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
1729 
1730  LogicalResult matchAndRewrite(LinalgOp op,
1731  PatternRewriter &rewriter) const override {
1732  for (OpOperand &opOperand : op->getOpOperands()) {
1733  // Linalg "inputs" may be either tensor or memref type.
1734  // tensor<0xelt_type> is a convention that may not always mean
1735  // "0 iterations". Only erase in cases we see memref<...x0x...>.
1736  auto mt = opOperand.get().getType().dyn_cast<MemRefType>();
1737  if (!mt)
1738  continue;
1739  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
1740  rewriter.eraseOp(op);
1741  return success();
1742  }
1743  }
1744  return failure();
1745  }
1746 };
1747 
1748 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
1749 /// result that is more static than the linalg op.
1750 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
1752 
1753  LogicalResult matchAndRewrite(tensor::CastOp castOp,
1754  PatternRewriter &rewriter) const override {
1755  if (!tensor::canFoldIntoProducerOp(castOp))
1756  return failure();
1757 
1758  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
1759  if (!linalgOp)
1760  return failure();
1761 
1762  // Cast can be in conditionally reachable region, if which case folding will
1763  // generate invalid code. Only conservatively fold ops in same block for
1764  // now.
1765  if (castOp->getBlock() != linalgOp->getBlock())
1766  return failure();
1767 
1768  OpBuilder::InsertionGuard guard(rewriter);
1769  rewriter.setInsertionPoint(linalgOp);
1770 
1771  Location loc = linalgOp.getLoc();
1772  OpResult resultValue = castOp.getSource().cast<OpResult>();
1773  unsigned resultNumber = resultValue.getResultNumber();
1774  auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
1775  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
1776  // going from a more dynamic shape to a less dynamic shape. If the producer
1777  // for this cast, i.e. producer of the out operand, is also an operation
1778  // that folds with tensor.cast consumer (like this pattern), the cast will
1779  // continue to propagate as far up the stack as it can go.
1780  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
1781  Value newOperand =
1782  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
1783  SmallVector<Value> newOperands{linalgOp.getDpsInputOperands()};
1784  SmallVector<Value> outputOperands{linalgOp.getDpsInitOperands()};
1785  outputOperands[resultNumber] = newOperand;
1786  newOperands.append(outputOperands.begin(), outputOperands.end());
1787 
1788  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
1789  linalgOp->result_type_end());
1790  resultTypes[resultNumber] = resultType;
1791  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
1792 
1793  // Create a tensor.cast operation back to the original type.
1794  Value castBack = rewriter.create<tensor::CastOp>(
1795  loc, resultValue.getType(), newOp->getResult(resultNumber));
1796 
1797  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
1798  results[resultNumber] = castBack;
1799  rewriter.replaceOp(linalgOp, results);
1800  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
1801  return success();
1802  }
1803 };
1804 
1805 /// For each of the operand in `operands` this function maps the static sizes of
1806 /// dimensions to their affine dim expressions.
1807 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
1808  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
1809  for (OpOperand &opOperand : operands) {
1810  if (linalgOp.isScalar(&opOperand))
1811  continue;
1812  Value src = opOperand.get();
1813  auto sourceType = src.getType().cast<RankedTensorType>();
1814  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
1815 
1816  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
1817  // `tensor.cast` operation and source of the cast operation has a static
1818  // shape, then assign it to the `sourceShape`.
1819  auto *parentOp = src.getDefiningOp();
1820  ArrayRef<int64_t> sourceShape = sourceType.getShape();
1821  if (parentOp) {
1822  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
1823  Value castSource = castOp.getSource();
1824  auto castSourceType = castSource.getType().dyn_cast<RankedTensorType>();
1825  if (castSourceType && castSourceType.hasStaticShape())
1826  sourceShape = castSourceType.getShape();
1827  }
1828  }
1829 
1830  // If the source shape's dimension has a static shape, map the affine dim
1831  // expression to the known static size.
1832  for (unsigned i = 0; i < sourceShape.size(); i++) {
1833  if (sourceType.isDynamicDim(i))
1834  continue;
1835  if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
1836  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
1837  }
1838  }
1839 }
1840 
1841 /// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
1842 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
1843 /// their result types is stored in `resultTypes`. If `opOperand` requires no
1844 /// change then `changeNeeded` is false and same operand is added in the
1845 /// `newOperands` list.
1846 static void createNewOperandWithStaticSizes(
1847  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
1848  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
1849  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
1850  bool &changeNeeded) {
1851  Value src = opOperand->get();
1852  newOperands.push_back(src);
1853  if (linalgOp.isScalar(opOperand))
1854  return;
1855  auto sourceType = src.getType().cast<RankedTensorType>();
1856  Type resultType = sourceType;
1857  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
1858  resultTypes.push_back(resultType);
1859  return;
1860  }
1861  ArrayRef<int64_t> sourceShape = sourceType.getShape();
1862  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
1863  SmallVector<int64_t> newShape;
1864  // If operand is updated with new shape, `newOperandNeeded` will be
1865  // true.
1866  bool newOperandNeeded = false;
1867  for (unsigned i = 0; i < sourceShape.size(); i++) {
1868  int64_t dimShape = sourceShape[i];
1869  AffineExpr dimExpr = sourceMap.getResult(i);
1870  if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
1871  !sourceType.isDynamicDim(i)) {
1872  newShape.push_back(dimShape);
1873  continue;
1874  }
1875  // Dimension has a dynamic shape and corresponding affine dim
1876  // expression is present in the map. So assign the size for the
1877  // given affine dim expression to the dimension.
1878  newShape.push_back(affineExprToSize[dimExpr]);
1879  newOperandNeeded = true;
1880  }
1881  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
1882  if (newOperandNeeded) {
1883  changeNeeded = true;
1884  // Get the new operand value given its size and element type by
1885  // casting it.
1886  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
1887  unsigned index = opOperand->getOperandNumber();
1888  newOperands[index] = newOperand;
1889  }
1890  if (linalgOp.isDpsInit(opOperand))
1891  resultTypes.push_back(resultType);
1892 }
1893 
1894 /// Static shapes for the operands can be inferred if any one of the operands
1895 /// have a static shape. This can be done by referring to the affine dim
1896 /// expressions for the operand.
1897 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
1899 
1900  LogicalResult matchAndRewrite(LinalgOp linalgOp,
1901  PatternRewriter &rewriter) const override {
1902  if (!linalgOp.hasTensorSemantics())
1903  return failure();
1904 
1905  // Maps must be projected permutations.
1906  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
1907  return !map.isProjectedPermutation();
1908  }))
1909  return failure();
1910 
1911  // Maps affine dim expressions to the static size of that dimension.
1912  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
1913  Location loc = linalgOp.getLoc();
1914 
1915  // For each of the affine dim expression, check if the size is known. If
1916  // known add that in the map.
1917  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
1918 
1919  SmallVector<Value> newOperands;
1920  SmallVector<Type> resultTypes;
1921 
1922  // `changeNeeded` is `false` if the operands of `linalgOp` require no
1923  // change in their types.
1924  bool changeNeeded = false;
1925  newOperands.reserve(linalgOp->getNumOperands());
1926  resultTypes.reserve(linalgOp.getNumDpsInits());
1927 
1928  // Iterate over all the operands and update the static sizes.
1929  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1930  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
1931  affineExprToSize, linalgOp, newOperands,
1932  resultTypes, changeNeeded);
1933  }
1934 
1935  // If the generic op has all the required static information, no
1936  // canonicalization needed.
1937  if (!changeNeeded)
1938  return failure();
1939 
1940  // Clone op.
1941  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
1942  SmallVector<Value> replacements;
1943  replacements.reserve(newOp->getNumResults());
1944  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
1945  Value newResult = std::get<1>(it);
1946  Value oldResult = std::get<0>(it);
1947  Type newType = newResult.getType();
1948  Type oldType = oldResult.getType();
1949  replacements.push_back(
1950  (newType != oldType)
1951  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
1952  : newResult);
1953  }
1954  rewriter.replaceOp(linalgOp, replacements);
1955  return success();
1956  }
1957 };
1958 
1959 } // namespace
1960 
1961 // All named ops canonicalizers and folders are auto-generated in the
1962 // .cpp.inc.
1963 
1964 //===----------------------------------------------------------------------===//
1965 // LinalgDialect
1966 //===----------------------------------------------------------------------===//
1967 
// Registers the dialect-wide canonicalization patterns that apply to any
// LinalgOp; op-specific canonicalizers and folders are auto-generated in the
// .cpp.inc (see the note above).
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}
1973 
1975  Attribute value, Type type,
1976  Location loc) {
1977  return builder.create<arith::ConstantOp>(loc, type, value);
1978 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static constexpr const bool value
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1326
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:193
static void buildStructuredOp(OpBuilder &b, OperationState &state, llvm::Optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:96
static void printCommonStructuredOpPartsWithNewLine(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:177
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:59
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:169
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:252
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1181
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1211
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:972
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, ValueRange results, const OpOperandVector &inputOperands, const OpOperandVector &outputOperands)
Definition: LinalgOps.cpp:847
static void appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:1687
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1602
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:219
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:627
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:212
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:245
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:126
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:698
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:42
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:256
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap dropResults(ArrayRef< int64_t > positions)
Definition: AffineMap.h:251
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:323
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:206
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
U cast() const
Definition: Attributes.h:137
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
unsigned getNumArguments()
Definition: Block.h:117
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:49
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:153
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:157
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:350
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:243
MLIRContext * getContext() const
Definition: Builders.h:54
Location getUnknownLoc()
Definition: Builders.cpp:26
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:247
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:300
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:137
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:150
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:300
This class helps build Operations.
Definition: Builders.h:198
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:383
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:350
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition: Builders.h:388
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:395
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents an operand of an operation.
Definition: Value.h:247
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:212
This is a value defined by a result of an operation.
Definition: Value.h:442
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:454
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
result_iterator result_begin()
Definition: Operation.h:330
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:356
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_type_range getOperandTypes()
Definition: Operation.h:314
result_iterator result_end()
Definition: Operation.h:331
result_type_range getResultTypes()
Definition: Operation.h:345
result_range getResults()
Definition: Operation.h:332
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:610
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:522
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
U dyn_cast() const
Definition: Types.h:270
bool isa() const
Definition: Types.h:260
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:93
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:79
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
U cast() const
Definition: Value.h:105
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:1679
AffineMap extractOrIdentityMap(Optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:1659
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:1708
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:1670
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:89
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
Definition: MPInt.h:370
MPInt ceil(const Fraction &f)
Definition: Fraction.h:70
MPInt floor(const Fraction &f)
Definition: Fraction.h:68
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:202
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:83
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1012
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:343
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:488
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:372
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:371
OpOperand vector that implicitly converts to a Value vector.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:356
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.