MLIR  19.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/PatternMatch.h"
37 
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/SmallSet.h"
40 #include "llvm/ADT/StringSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
// NOTE(review): the declaration line (original line 51) is missing from this
// extract — the parameters `builder`, `loc`, `v` are used below but declared
// on the lost line; confirm the exact signature against upstream.
52  int64_t dim) {
53  auto type = cast<ShapedType>(v.getType());
// Static dimension: fold directly to an index attribute, no op is created.
54  if (!type.isDynamicDim(dim))
55  return builder.getIndexAttr(type.getDimSize(dim));
56 
// Dynamic dimension: materialize the appropriate dim op for the value's type.
// NOTE(review): the TypeSwitch head (original line 58) is missing here.
57  return getAsOpFoldResult(
59  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
60  return builder.create<tensor::DimOp>(loc, v, dim);
61  })
62  .Case<MemRefType>([&](MemRefType t) -> Value {
63  return builder.create<memref::DimOp>(loc, v, dim);
64  }));
65 }
66 
67 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
68 /// `source`.
// NOTE(review): one parameter line (original line 71, presumably
// `ArrayRef<OpFoldResult> sizes,`) is missing from this extract; `sizes` is
// referenced in both lambdas below.
69 static Value getSlice(OpBuilder &b, Location loc, Value source,
70  ArrayRef<OpFoldResult> offsets,
72  ArrayRef<OpFoldResult> strides) {
73  return TypeSwitch<Type, Value>(source.getType())
// Tensors take a slice via tensor.extract_slice.
74  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
75  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
76  strides);
77  })
// Memrefs take a view via memref.subview.
78  .Case<MemRefType>([&](MemRefType type) -> Value {
79  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
80  strides);
81  })
// Any other type yields a null Value; callers must handle that.
82  .Default([&](Type t) { return nullptr; });
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Helper functions
87 //===----------------------------------------------------------------------===//
88 
// Creates (or folds) a dim op for `source` at `dim`, dispatching on whether
// the source is memref-like or tensor-like.
// NOTE(review): the declaration line (original line 89) is missing; the call
// at createFoldedDimOp below (`createOrFoldDimOp(b, loc, source, dim)`)
// indicates the parameters are (b, loc, source, dim) — confirm upstream.
90  int64_t dim) {
91  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
92  return b.createOrFold<memref::DimOp>(loc, source, dim);
93  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
94  return b.createOrFold<tensor::DimOp>(loc, source, dim);
// Any other type is a programmer error.
95  llvm_unreachable("Expected MemRefType or TensorType");
96 }
97 
// Returns the size of `source` at `dim` as an OpFoldResult: a constant index
// attribute when the dimension is statically known, otherwise a dim op.
// NOTE(review): the declaration line (original line 98) is missing from this
// extract; parameters `b`, `loc`, `source` are declared on the lost line.
99  int64_t dim) {
100  auto shapedType = llvm::cast<ShapedType>(source.getType());
// Unranked or dynamic dims need a runtime dim op.
101  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
102  return createOrFoldDimOp(b, loc, source, dim);
// Static dims fold to an attribute without creating any op.
103  return b.getIndexAttr(shapedType.getDimSize(dim));
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Support for named Linalg ops defined in ods-gen.
108 //===----------------------------------------------------------------------===//
109 
112 
113 /// Fills the region of a structured operation using the provided
114 /// `regionBuilder`. The method is used by both named structured ops created by
115 /// ods-gen and by manually defined C++ ops. It is called by both builders and
116 /// parsers and creates a block with arguments corresponding to the elemental
117 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
118 /// ShapedType.
// NOTE(review): one parameter line (original line 121, presumably
// `ArrayRef<NamedAttribute> attrs,`) is missing from this extract; `attrs` is
// passed to `regionBuilder` below.
119 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
120  TypeRange inputTypes, TypeRange outputTypes,
122  RegionBuilderFn regionBuilder) {
123  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
124 
// Block arguments are the elemental types of shaped operands; scalar operands
// keep their type as-is.
125  SmallVector<Type, 8> argTypes;
126  SmallVector<Location, 8> argLocs;
127  for (auto containers : {inputTypes, outputTypes}) {
128  for (auto t : containers) {
129  argTypes.push_back(
130  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
131 
132  // TODO: Pass in a proper location here.
133  argLocs.push_back(opBuilder.getUnknownLoc());
134  }
135  }
136 
137  // RAII.
138  OpBuilder::InsertionGuard guard(opBuilder);
139  Block *body =
140  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
141 
// Delegate the actual payload construction to the op-specific builder.
142  opBuilder.setInsertionPointToStart(body);
143  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
144  regionBuilder(b, *body, attrs);
145 
146  // indexing_maps is an auto-generated method.
147 
148  // iterator_types is an auto-generated method.
149 }
150 
151 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
152 /// The result types are derived automatically if `resultTensorTypes` is none.
153 /// The body of the operation is filled using `regionBuilder`. All ods-gen
154 /// created structured operations use the method to implement their builders.
// NOTE(review): the opening signature line (original line 155) is missing
// from this extract; the body uses `b` (OpBuilder) and `state`
// (OperationState), which are declared on the lost line — confirm upstream.
156  std::optional<TypeRange> resultTensorTypes,
157  ValueRange inputs, ValueRange outputs,
158  ArrayRef<NamedAttribute> attributes,
159  RegionBuilderFn regionBuilder) {
160  // Derive the result types if needed.
// When no explicit result types are given, every ranked-tensor output becomes
// a result (memref outputs produce no results).
161  SmallVector<Type> derivedResultTypes =
162  resultTensorTypes.value_or(TypeRange());
163  if (!resultTensorTypes)
164  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
165  llvm::IsaPred<RankedTensorType>);
166 
167  state.addOperands(inputs);
168  state.addOperands(outputs);
169  state.addTypes(derivedResultTypes);
170  state.addAttributes(attributes);
// Record the ins/outs split so the op can recover its operand segments.
171  state.addAttribute(
172  "operandSegmentSizes",
173  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
174  static_cast<int32_t>(outputs.size())}));
175 
176  // Create and fill the region of the structured operation.
177  Region &region = *state.addRegion();
178  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
179  state.attributes.getAttrs(), regionBuilder);
180 }
181 
182 /// Common parsing used for both named structured ops created by ods-gen and by
183 /// manually defined C++ ops. Does not handle regions.
// NOTE(review): this extract is missing the line declaring the first
// parameters (original line 185, presumably `parseCommonStructuredOpParts(
// OpAsmParser &parser, OperationState &result,`), the declaration of the
// operand vectors (original line 190), and the `getDenseI32ArrayAttr(`
// call heads (original lines 234 and 240). Confirm against upstream.
184 static ParseResult
186  SmallVectorImpl<Type> &inputTypes,
187  SmallVectorImpl<Type> &outputTypes,
188  bool addOperandSegmentSizes = true) {
189  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
191  outputsOperands;
192 
// Optional `<...>` properties dictionary preceding the attr-dict.
193  if (succeeded(parser.parseOptionalLess())) {
194  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
195  return failure();
196  }
197  attrsLoc = parser.getCurrentLocation();
198  if (parser.parseOptionalAttrDict(result.attributes))
199  return failure();
200 
// Optional `ins(%a, %b : typeA, typeB)` clause.
201  if (succeeded(parser.parseOptionalKeyword("ins"))) {
202  if (parser.parseLParen())
203  return failure();
204 
205  inputsOperandsLoc = parser.getCurrentLocation();
206  if (parser.parseOperandList(inputsOperands) ||
207  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
208  return failure();
209  }
210 
// Optional `outs(%c : typeC)` clause.
211  if (succeeded(parser.parseOptionalKeyword("outs"))) {
212  outputsOperandsLoc = parser.getCurrentLocation();
213  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
214  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
215  return failure();
216  }
217 
218  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
219  result.operands) ||
220  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
221  result.operands))
222  return failure();
223 
224  if (addOperandSegmentSizes) {
225  // This is a bit complex because we're trying to be backward compatible with
226  // operation syntax that mix the inherent attributes and the discardable
227  // ones in the same dictionary. If the properties are used, we append the
228  // operandSegmentSizes there directly. Otherwise we append it to the
229  // discardable attributes dictionary where it is handled by the generic
230  // Operation::create(...) method.
231  if (result.propertiesAttr) {
232  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
233  attrs.append("operandSegmentSizes",
235  {static_cast<int32_t>(inputsOperands.size()),
236  static_cast<int32_t>(outputsOperands.size())}));
237  result.propertiesAttr = attrs.getDictionary(parser.getContext());
238  } else {
239  result.addAttribute("operandSegmentSizes",
241  {static_cast<int32_t>(inputsOperands.size()),
242  static_cast<int32_t>(outputsOperands.size())}));
243  }
244  }
// Without properties, inherent attributes live in the discardable dict and
// must be verified explicitly against the registered op definition.
245  if (!result.propertiesAttr) {
246  std::optional<RegisteredOperationName> info =
247  result.name.getRegisteredInfo();
248  if (info) {
249  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
250  return parser.emitError(attrsLoc)
251  << "'" << result.name.getStringRef() << "' op ";
252  })))
253  return failure();
254  }
255  }
256  return success();
257 }
258 
// Prints the shared `ins(...)`/`outs(...)` clauses; empty ranges are elided.
// NOTE(review): the declaration line (original line 259, presumably
// `static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange
// inputs,`) is missing from this extract — confirm upstream.
260  ValueRange outputs) {
261  if (!inputs.empty())
262  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
263  if (!outputs.empty())
264  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Specific parsing and printing for named structured ops created by ods-gen.
269 //===----------------------------------------------------------------------===//
270 
271 static ParseResult parseNamedStructuredOpRegion(
272  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
273  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
274  RegionBuilderFn regionBuilder) {
275  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
276  return parser.emitError(
277  parser.getCurrentLocation(),
278  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
279  "region expects {0} args, got {1}",
280  numRegionArgs, inputTypes.size() + outputTypes.size()));
281  }
282 
283  OpBuilder opBuilder(parser.getContext());
284  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
285  regionBuilder);
286  return success();
287 }
288 
// Parses the optional `-> type, ...` result list of a named structured op.
// NOTE(review): the line carrying the function name and first parameter
// (original line 290, presumably `parseNamedStructuredOpResults(OpAsmParser
// &parser,`) is missing from this extract — confirm upstream.
289 static ParseResult
291  SmallVectorImpl<Type> &resultTypes) {
292  if (parser.parseOptionalArrowTypeList(resultTypes))
293  return failure();
294  return success();
295 }
296 
297 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
298  OperationState &result,
299  unsigned numRegionArgs,
300  RegionBuilderFn regionBuilder) {
301  // TODO: Enable when ods-gen supports captures.
302  SmallVector<Type, 1> inputTypes, outputTypes;
303  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
304  return failure();
305 
306  // TODO: consider merging results parsing into region parsing.
307  // Need to wait for declarative assembly resolution to decide.
308  SmallVector<Type, 1> outputTensorsTypes;
309  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
310  return failure();
311  result.addTypes(outputTensorsTypes);
312 
313  std::unique_ptr<Region> region = std::make_unique<Region>();
314  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
315  outputTypes, result.attributes.getAttrs(),
316  regionBuilder))
317  return failure();
318  result.addRegion(std::move(region));
319 
320  return success();
321 }
322 
// Prints the optional `-> type, ...` result list; elided when there are no
// results.
// NOTE(review): the declaration line (original line 323, presumably
// `static void printNamedStructuredOpResults(OpAsmPrinter &p,`) is missing
// from this extract — confirm upstream.
324  TypeRange resultTypes) {
325  if (resultTypes.empty())
326  return;
327  p.printOptionalArrowTypeList(resultTypes);
328 }
329 
// Prints a named structured op: attr-dict (with memoized/segment attrs
// elided), shared ins/outs clauses, result types; the region is elided.
// NOTE(review): this extract is missing the opening signature line (original
// line 330), the head of the attr-dict printing call (original line 332,
// presumably `p.printOptionalAttrDict(`), and the results-printing call
// (original line 344). Confirm against upstream.
331  ValueRange inputs, ValueRange outputs) {
333  op->getAttrs(),
334  /*elidedAttrs=*/{"operandSegmentSizes",
335  // See generated code in
336  // LinalgNamedStructuredOps.yamlgen.cpp.inc
337  "linalg.memoized_indexing_maps"});
338 
339  // Printing is shared with generic ops, except for the region and
340  // attributes.
341  printCommonStructuredOpParts(p, inputs, outputs);
342 
343  // Results printing.
345 
346  // Region is elided.
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // Region builder helper.
351 // TODO: Move this to a utility library.
352 // The public methods on this class are referenced directly from generated code.
353 // Helper build the unary, binary, and type conversion functions defined by the
354 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
355 // class.
356 //
357 // Implementations of the math functions must be polymorphic over numeric types,
358 // internally performing necessary casts. If the function application makes no
359 // sense, then the only recourse is to assert and return nullptr. This can be
360 // extended later if it becomes possible to fail construction of the region. The
361 // invariant should be enforced at a higher level.
362 //
363 // TODO: These helpers are currently type polymorphic over the class of integer
364 // and floating point types, but they will not internally cast within bit
365 // widths of a class (mixed precision such as i8->i32) or across classes
366 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
367 // to be handled with care and work is being considered to extend the op
368 // language to make such cases explicit. In the mean-time, violating this will
369 // fail verification, which is deemed acceptable.
370 //===----------------------------------------------------------------------===//
371 
372 namespace {
373 
374 class RegionBuilderHelper {
375 public:
376  RegionBuilderHelper(OpBuilder &builder, Block &block)
377  : builder(builder), block(block) {}
378 
379  // Build the unary functions defined by OpDSL.
380  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
381  if (!isFloatingPoint(arg))
382  llvm_unreachable("unsupported non numeric type");
383  OpBuilder::InsertionGuard g(builder);
384  builder.setInsertionPointToEnd(&block);
385  switch (unaryFn) {
386  case UnaryFn::exp:
387  return builder.create<math::ExpOp>(arg.getLoc(), arg);
388  case UnaryFn::log:
389  return builder.create<math::LogOp>(arg.getLoc(), arg);
390  case UnaryFn::abs:
391  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
392  case UnaryFn::ceil:
393  return builder.create<math::CeilOp>(arg.getLoc(), arg);
394  case UnaryFn::floor:
395  return builder.create<math::FloorOp>(arg.getLoc(), arg);
396  case UnaryFn::negf:
397  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
398  case UnaryFn::reciprocal: {
399  Attribute oneAttr = builder.getOneAttr(arg.getType());
400  auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
401  ::cast<TypedAttr>(oneAttr));
402  return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
403  }
404  case UnaryFn::round:
405  return builder.create<math::RoundOp>(arg.getLoc(), arg);
406  case UnaryFn::sqrt:
407  return builder.create<math::SqrtOp>(arg.getLoc(), arg);
408  case UnaryFn::rsqrt:
409  return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
410  case UnaryFn::square:
411  return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
412  case UnaryFn::tanh:
413  return builder.create<math::TanhOp>(arg.getLoc(), arg);
414  case UnaryFn::erf:
415  return builder.create<math::ErfOp>(arg.getLoc(), arg);
416  }
417  llvm_unreachable("unsupported unary function");
418  }
419 
420  // Build the binary functions defined by OpDSL.
421  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
422  bool allComplex = isComplex(arg0) && isComplex(arg1);
423  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
424  bool allInteger = isInteger(arg0) && isInteger(arg1);
425  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
426  arg1.getType().getIntOrFloatBitWidth() == 1;
427  if (!allComplex && !allFloatingPoint && !allInteger)
428  llvm_unreachable("unsupported non numeric type");
429  OpBuilder::InsertionGuard g(builder);
430  builder.setInsertionPointToEnd(&block);
431  switch (binaryFn) {
432  case BinaryFn::add:
433  if (allComplex)
434  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
435  if (allFloatingPoint)
436  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
437  if (allBool)
438  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
439  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
440  case BinaryFn::sub:
441  if (allComplex)
442  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
443  if (allFloatingPoint)
444  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
445  if (allBool)
446  llvm_unreachable("unsupported operation: sub with bools");
447  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
448  case BinaryFn::mul:
449  if (allComplex)
450  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
451  if (allFloatingPoint)
452  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
453  if (allBool)
454  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
455  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
456  case BinaryFn::div:
457  if (allComplex)
458  return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
459  if (allFloatingPoint)
460  return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
461  if (allBool)
462  llvm_unreachable("unsupported operation: div with bools");
463  return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
464  case BinaryFn::div_unsigned:
465  if (!allInteger || allBool)
466  llvm_unreachable("unsupported operation: unsigned div not on uint");
467  return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
468  case BinaryFn::max_signed:
469  assert(!allComplex);
470  if (allFloatingPoint)
471  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
472  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
473  case BinaryFn::min_signed:
474  assert(!allComplex);
475  if (allFloatingPoint)
476  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
477  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
478  case BinaryFn::max_unsigned:
479  assert(!allComplex);
480  if (allFloatingPoint)
481  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
482  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
483  case BinaryFn::min_unsigned:
484  assert(!allComplex);
485  if (allFloatingPoint)
486  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
487  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
488  case BinaryFn::powf:
489  assert(allFloatingPoint);
490  return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
491  }
492  llvm_unreachable("unsupported binary function");
493  }
494 
495  // Build the ternary functions defined by OpDSL.
496  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
497  Value arg2) {
498  bool headBool =
499  isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
500  bool tailFloatingPoint =
501  isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
502  bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
503  OpBuilder::InsertionGuard g(builder);
504  builder.setInsertionPointToEnd(&block);
505  switch (ternaryFn) {
506  case TernaryFn::select:
507  if (!headBool && !(tailFloatingPoint || tailInteger))
508  llvm_unreachable("unsupported non numeric type");
509  return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
510  }
511  llvm_unreachable("unsupported ternary function");
512  }
513 
514  // Build the type functions defined by OpDSL.
515  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
516  switch (typeFn) {
517  case TypeFn::cast_signed:
518  return cast(toType, operand, false);
519  case TypeFn::cast_unsigned:
520  return cast(toType, operand, true);
521  }
522  llvm_unreachable("unsupported type conversion function");
523  }
524 
525  void yieldOutputs(ValueRange values) {
526  OpBuilder::InsertionGuard g(builder);
527  builder.setInsertionPointToEnd(&block);
528  Location loc = builder.getUnknownLoc();
529  builder.create<YieldOp>(loc, values);
530  }
531 
532  Value constant(const std::string &value) {
533  OpBuilder::InsertionGuard g(builder);
534  builder.setInsertionPointToEnd(&block);
535  Location loc = builder.getUnknownLoc();
536  Attribute valueAttr = parseAttribute(value, builder.getContext());
537  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
538  }
539 
540  Value index(int64_t dim) {
541  OpBuilder::InsertionGuard g(builder);
542  builder.setInsertionPointToEnd(&block);
543  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
544  }
545 
546  Type getIntegerType(unsigned width) {
547  return IntegerType::get(builder.getContext(), width);
548  }
549 
550  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
551  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
552 
553 private:
554  // Generates operations to cast the given operand to a specified type.
555  // If the cast cannot be performed, a warning will be issued and the
556  // operand returned as-is (which will presumably yield a verification
557  // issue downstream).
558  Value cast(Type toType, Value operand, bool isUnsignedCast) {
559  OpBuilder::InsertionGuard g(builder);
560  builder.setInsertionPointToEnd(&block);
561  auto loc = operand.getLoc();
562  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
563  }
564 
565  bool isComplex(Value value) {
566  return llvm::isa<ComplexType>(value.getType());
567  }
568  bool isFloatingPoint(Value value) {
569  return llvm::isa<FloatType>(value.getType());
570  }
571  bool isInteger(Value value) {
572  return llvm::isa<IntegerType>(value.getType());
573  }
574 
575  OpBuilder &builder;
576  Block &block;
577 };
578 
579 } // namespace
580 
581 //===----------------------------------------------------------------------===//
582 // CopyOp
583 //===----------------------------------------------------------------------===//
584 
585 namespace {
586 
// Erases a linalg.copy whose inputs and outputs are the same values: such a
// copy is a no-op on buffers, and on tensors it can be replaced by its input.
// NOTE(review): original line 588 is missing from this extract — presumably
// the inherited-constructor `using` declaration; confirm upstream.
587 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
589  LogicalResult matchAndRewrite(CopyOp copyOp,
590  PatternRewriter &rewriter) const override {
591  if (copyOp.getInputs() != copyOp.getOutputs())
592  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
// Buffer form has no results, so the op can simply be erased; tensor form
// must forward its inputs to the copy's users.
593  if (copyOp.hasPureBufferSemantics())
594  rewriter.eraseOp(copyOp);
595  else
596  rewriter.replaceOp(copyOp, copyOp.getInputs());
597 
598  return success();
599  }
600 };
601 
602 } // namespace
603 
604 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
605  MLIRContext *context) {
606  results.add<EraseSelfCopy>(context);
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // FillOp
611 //===----------------------------------------------------------------------===//
612 
613 namespace {
614 
615 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
616 ///
617 /// For such op chains, we can create new linalg.fill ops with the result
618 /// type of the tensor.expand/collapse_shape op.
// NOTE(review): original line 621 is missing from this extract — presumably
// the inherited-constructor `using` declaration; confirm upstream.
619 template <typename TensorReshapeOp>
620 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
622  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
623  PatternRewriter &rewriter) const override {
624  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
625  if (!oldFill)
626  return failure();
627 
// Reshape the fill's init operand instead of its result, then re-fill.
// tensor.expand_shape takes the extra output-shape operands/attribute that
// tensor.collapse_shape does not have, hence the `if constexpr` split.
628  Location loc = oldFill.getLoc();
629  TensorReshapeOp newInit;
630  if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
631 
632  newInit = rewriter.create<TensorReshapeOp>(
633  loc, reshapeOp.getResultType(), oldFill.output(),
634  reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
635  reshapeOp.getStaticOutputShape());
636  } else {
637  newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
638  oldFill.output(),
639  reshapeOp.getReassociation());
640  }
641  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
642  ValueRange{newInit});
643  return success();
644  }
645 };
646 
647 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
648 /// filling value are the same.
// NOTE(review): original line 650 is missing from this extract — presumably
// the inherited-constructor `using` declaration; confirm upstream.
649 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
651 
652  LogicalResult matchAndRewrite(tensor::PadOp padOp,
653  PatternRewriter &rewriter) const override {
654  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
655  if (!fillOp)
656  return failure();
657 
658  // We can only fold if the padding value is the same as the original
659  // filling value.
660  Value padValue = padOp.getConstantPaddingValue();
661  if (!padValue || fillOp.value() != padValue)
662  return failure();
663 
// The replacement fill needs an init of the padded (larger) shape; reify the
// pad's result shape to build an equivalent tensor.empty.
664  ReifiedRankedShapedTypeDims reifiedShape;
665  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
666  return rewriter.notifyMatchFailure(
667  padOp, "failed to reify tensor.pad op result shape");
668 
669  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
670  padOp.getLoc(), reifiedShape.front(),
671  padOp.getResultType().getElementType());
672  Value replacement =
673  rewriter
674  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
675  ValueRange{emptyTensor})
676  .getResult(0);
// Insert a cast if reification produced a more dynamic type than the pad's
// result type.
677  if (replacement.getType() != padOp.getResultType()) {
678  replacement = rewriter.create<tensor::CastOp>(
679  fillOp.getLoc(), padOp.getResultType(), replacement);
680  }
681  rewriter.replaceOp(padOp, replacement);
682  return success();
683  }
684 };
685 
686 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
687 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
688 /// filling value are the same.
// NOTE(review): original line 690 is missing from this extract — presumably
// the inherited-constructor `using` declaration; original line 768 (the
// declaration of `newSizes`, used below) is also missing. Confirm upstream.
689 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
691 
692  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
693  PatternRewriter &rewriter) const override {
694  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
695  if (!srcPadOp)
696  return failure();
697 
// Only handle rank-preserving insert_slices.
698  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
699  return failure();
700 
701  // Walk back the tensor.insert_slice chain and find the first destination
702  // value at the start of the chain.
703  Value firstDest = insertOp.getDest();
704  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
705  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
706  return failure();
707 
708  // Make sure the range of values accessed are disjoint. Without this, we
709  // cannot fold tensor.pad away.
710  bool disjoint = false;
711  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
712  // If the dimension has dynamic offset/size, we cannot guarantee
713  // disjoint. So just skip it.
714  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
715  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
716  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
717  continue;
718 
719  // Get the range start and end, inclusively for both.
720  int64_t prevStart = prevOp.getStaticOffset(i);
721  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
722  prevOp.getStaticStride(i);
723  int64_t nextStart = insertOp.getStaticOffset(i);
724  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
725  insertOp.getStaticStride(i);
726  if (prevEnd < nextStart || nextEnd < prevStart) {
727  disjoint = true;
728  break;
729  }
730  }
731 
// Overlapping (or unprovably disjoint) slice: stop walking here.
732  if (!disjoint)
733  break;
734  firstDest = prevOp.getDest();
735  }
736 
737  // Check whether the first destination is a fill op. For overlapped cases,
738  // this also cannot be true.
739  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
740  if (!dstFillOp)
741  return failure();
742 
743  // We can only fold if the padding value is the same as the original
744  // filling value.
745  Value padValue = srcPadOp.getConstantPaddingValue();
746  if (!padValue || dstFillOp.value() != padValue)
747  return failure();
748 
749  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
750  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
751 
752  Location loc = insertOp.getLoc();
753  MLIRContext *context = getContext();
754 
755  AffineExpr sym0, sym1;
756  bindSymbols(context, sym0, sym1);
757  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
758 
759  // Calculate the new offsets for the insert. It should be the old offsets
760  // plus low padding sizes.
761  SmallVector<OpFoldResult, 4> newOffsets;
762  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
763  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
764  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
765  }
766 
// Sizes of the new insert are the sizes of the pad's (unpadded) source.
767  RankedTensorType srcPadType = srcPadOp.getSourceType();
769  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
770  if (srcPadType.isDynamicDim(i)) {
771  newSizes.push_back(
772  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
773  .getResult());
774  } else {
775  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
776  }
777  }
778 
779  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
780  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
781  newSizes, insertOp.getMixedStrides());
782  return success();
783  }
784 };
785 
786 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
// Every element of a filled tensor is the fill's scalar input, so any
// extract can be replaced by that scalar directly.
// NOTE(review): original line 789 is missing from this extract — presumably
// the inherited-constructor `using` declaration; confirm upstream.
787 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
788 public:
790 
791  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
792  PatternRewriter &rewriter) const override {
793  // See if tensor input of tensor.extract op is the result of a linalg.fill
794  // op.
795  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
796  if (!fillOp)
797  return failure();
798 
799  // Get scalar input operand of linalg.fill op.
800  Value extractedScalar = fillOp.getInputs()[0];
801 
802  // Replace tensor.extract op with scalar value used to fill the tensor.
803  rewriter.replaceOp(extractOp, extractedScalar);
804  return success();
805  }
806 };
807 
808 /// Folds pack(fill) into a single fill op if
809 /// 1. The pack op does not have padding value, or
810 /// 2. The filled value and padding value are the same.
811 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
812  tensor::PackOp packOp) {
813  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
814  if (!fillOp)
815  return failure();
816 
817  if (auto paddingValue = packOp.getPaddingValue())
818  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
819  return failure();
820 
821  Value packOpDest = packOp.getDest();
822  if (!packOpDest.hasOneUse())
823  return failure();
824 
825  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
826  packOp.getDest());
827 }
828 
829 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
830 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
831 public:
832  FoldFillWithPack(MLIRContext *context)
833  : OpRewritePattern<tensor::PackOp>(context) {}
834 
835  LogicalResult matchAndRewrite(tensor::PackOp packOp,
836  PatternRewriter &rewriter) const override {
837  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
838  if (failed(fillOp))
839  return failure();
840  rewriter.replaceOp(packOp, fillOp.value().result());
841  return success();
842  }
843 };
844 
845 /// Fold fill with copy.
// NOTE(review): original line 847 is missing from this extract — presumably
// the inherited-constructor `using` declaration; confirm upstream.
846 struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
848 
849  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
850  PatternRewriter &rewriter) const override {
// Case 1: copying *from* a fill — fill the copy's outputs directly.
851  if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
852  rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
853  fillOp.getInputs(),
854  copyOp.getOutputs());
855  return success();
856  }
// Case 2: copying *onto* a fill result — the fill is overwritten, so copy
// into the fill's own output instead.
857  if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
858  rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
859  fillOp.getOutputs());
860  return success();
861  }
862  return failure();
863  }
864 };
865 
866 /// Fold fill with transpose.
867 struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
869 
870  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
871  PatternRewriter &rewriter) const override {
872  if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
873  rewriter.replaceOpWithNewOp<FillOp>(
874  transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
875  transposeOp.getDpsInitOperand(0)->get());
876  return success();
877  }
878  return failure();
879  }
880 };
881 
882 } // namespace
883 
/// Registers all fill-folding canonicalization patterns defined above with
/// `results`.
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
892 
893 //===----------------------------------------------------------------------===//
894 // GenericOp
895 //===----------------------------------------------------------------------===//
896 
897 static void buildGenericRegion(
898  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
899  ValueRange outputs,
900  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
901  SmallVector<Type, 4> blockArgTypes;
902  SmallVector<Location, 4> blockArgLocs;
903  for (ValueRange container : {inputs, outputs}) {
904  for (Value v : container) {
905  Type t = v.getType();
906  blockArgTypes.push_back(
907  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
908  blockArgLocs.push_back(v.getLoc());
909  }
910  }
911 
912  OpBuilder::InsertionGuard guard(builder);
913  Block *bodyBlock =
914  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
915  bodyBuild(builder, loc, bodyBlock->getArguments());
916 }
917 
918 void GenericOp::getAsmBlockArgumentNames(Region &region,
919  OpAsmSetValueNameFn setNameFn) {
920  for (Value v : getRegionInputArgs())
921  setNameFn(v, "in");
922  for (Value v : getRegionOutputArgs())
923  setNameFn(v, "out");
924 }
925 
/// Builds a GenericOp from attribute-form indexing maps and iterator types,
/// forwarding `attributes` onto the op and, when `bodyBuild` is non-null,
/// materializing the payload region.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  // Only populate the region when a body builder was supplied; some callers
  // construct the region themselves afterwards.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
939 
940 void GenericOp::build(
941  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
942  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
943  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
944  StringRef libraryCall,
945  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
946  ArrayRef<NamedAttribute> attributes) {
947  build(builder, result, resultTensorTypes, inputs, outputs,
948  builder.getAffineMapArrayAttr(indexingMaps),
949  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
950  iteratorTypes,
951  [&](utils::IteratorType iter) -> mlir::Attribute {
952  return IteratorTypeAttr::get(builder.getContext(), iter);
953  }))),
954  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
955  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
956  bodyBuild, attributes);
957 }
958 
/// Convenience overload without result tensor types: delegates with an empty
/// TypeRange (used e.g. when the op has no tensor results).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
969 
/// Convenience overload without result types, doc string, or library call:
/// both strings are passed empty (and hence elided as attributes).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
980 
/// Convenience overload with result tensor types but without doc string or
/// library call: both strings are passed empty (and hence elided).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
992 
994  p << " ";
995 
996  // Print extra attributes.
997  auto genericAttrNames = linalgTraitAttrNames();
998 
999  llvm::StringSet<> genericAttrNamesSet;
1000  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
1001  SmallVector<NamedAttribute, 8> genericAttrs;
1002  for (auto attr : (*this)->getAttrs()) {
1003  if (attr.getName() == getIteratorTypesAttrName()) {
1004  auto iteratorTypes =
1005  llvm::cast<ArrayAttr>(attr.getValue())
1006  .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1007  // Convert IteratorType enums into the string representation. This is
1008  // needed, because tests still use the old format when 'iterator_types'
1009  // attribute is represented as an array of strings.
1010  // TODO: Remove this conversion once tests are fixed.
1011  SmallVector<Attribute> iteratorTypeNames =
1012  llvm::to_vector(llvm::map_range(
1013  iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1014  return StringAttr::get(getContext(), stringifyIteratorType(t));
1015  }));
1016 
1017  genericAttrs.emplace_back(
1018  getIteratorTypesAttrName(),
1019  ArrayAttr::get(getContext(), iteratorTypeNames));
1020  } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1021  genericAttrs.push_back(attr);
1022  }
1023  }
1024  if (!genericAttrs.empty()) {
1025  auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1026  p << genericDictAttr;
1027  }
1028 
1029  // Printing is shared with named ops, except for the region and attributes
1030  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1031 
1032  genericAttrNames.push_back("operandSegmentSizes");
1033  genericAttrNamesSet.insert(genericAttrNames.back());
1034 
1035  bool hasExtraAttrs = false;
1036  for (NamedAttribute n : (*this)->getAttrs()) {
1037  if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1038  break;
1039  }
1040  if (hasExtraAttrs) {
1041  p << " attrs = ";
1042  p.printOptionalAttrDict((*this)->getAttrs(),
1043  /*elidedAttrs=*/genericAttrNames);
1044  }
1045 
1046  // Print region.
1047  if (!getRegion().empty()) {
1048  p << ' ';
1049  p.printRegion(getRegion());
1050  }
1051 
1052  // Print results.
1053  printNamedStructuredOpResults(p, getResultTensors().getTypes());
1054 }
1055 
/// Parses the generic op custom form: leading trait dictionary, ins/outs,
/// optional `attrs = {...}`, region, and trailing result types.
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  // Replace the string-form attribute with its enum-attr equivalent.
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  // The region's entry-block arguments are printed inline with the region.
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1121 
1124  &effects,
1125  LinalgOp linalgOp) {
1126  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1127  if (!llvm::isa<MemRefType>(operand.getType()))
1128  continue;
1129  effects.emplace_back(
1130  MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1131  /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1132  }
1133 
1134  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1135  if (!llvm::isa<MemRefType>(operand.get().getType()))
1136  continue;
1137  if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1138  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1139  /*effectOnFullRegion=*/true,
1141  }
1142  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1143  /*effectOnFullRegion=*/true,
1145  }
1146 }
1147 
1148 void GenericOp::getEffects(
1150  &effects) {
1151  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1152 }
1153 
// Nothing GenericOp-specific to check here; presumably the shared
// structured-op/LinalgOp interface verification covers the invariants —
// TODO(review): confirm.
LogicalResult GenericOp::verify() { return success(); }
1155 
1156 namespace {
1157 
1158 /// Remove any linalg operation (on tensors) that are just copying
1159 /// the values from inputs to the results. Requirements are
1160 /// 1) All iterator types are parallel
1161 /// 2) The body contains just a yield operation with the yielded values being
1162 /// the arguments corresponding to the operands.
1163 template <typename OpTy>
1164 struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1166 
1167  LogicalResult matchAndRewrite(OpTy linalgOp,
1168  PatternRewriter &rewriter) const override {
1169  // Check all indexing maps are identity.
1170  if (llvm::any_of(linalgOp.getIndexingMapsArray(),
1171  [](AffineMap map) { return !map.isIdentity(); }))
1172  return failure();
1173 
1174  // Check that the body of the linalg operation is just a linalg.yield
1175  // operation.
1176  Block &body = linalgOp->getRegion(0).front();
1177  if (!llvm::hasSingleElement(body))
1178  return failure();
1179  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1180  if (!yieldOp)
1181  return failure();
1182 
1183  // In the buffer case, we need to check exact buffer equality.
1184  if (linalgOp.hasPureBufferSemantics()) {
1185  if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1186  linalgOp.getDpsInputOperand(0)->get() ==
1187  linalgOp.getDpsInitOperand(0)->get()) {
1188  rewriter.eraseOp(linalgOp);
1189  return success();
1190  }
1191  return failure();
1192  }
1193 
1194  // Mixed semantics is not supported yet.
1195  if (!linalgOp.hasPureTensorSemantics())
1196  return failure();
1197 
1198  // Get the argument number of the returned values. That is the operand
1199  // number to use for replacing uses of this operation.
1200  SmallVector<Value> returnedArgs;
1201  for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1202  auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1203  if (!yieldArg || yieldArg.getOwner() != &body)
1204  return failure();
1205  unsigned argumentNumber = yieldArg.getArgNumber();
1206  Value returnedArg = linalgOp->getOperand(argumentNumber);
1207  Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1208  // The input can have a different type than the result, e.g. a dynamic
1209  // input dimension can be turned into a static output dimension.
1210  Type returnType = returnedArg.getType();
1211  if (returnType != resultType) {
1212  // Distinguish between sparse conversion or dense tensor casting.
1213  // TODO: unify the two ops?
1214  if (sparse_tensor::getSparseTensorEncoding(returnType) ||
1216  returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
1217  linalgOp.getLoc(), resultType, returnedArg);
1218  else {
1219  if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1220  resultType))
1221  return failure();
1222  returnedArg = rewriter.create<tensor::CastOp>(
1223  linalgOp.getLoc(), resultType, returnedArg);
1224  }
1225  }
1226  returnedArgs.push_back(returnedArg);
1227  }
1228 
1229  if (returnedArgs.size() != linalgOp->getNumResults())
1230  return failure();
1231  rewriter.replaceOp(linalgOp, returnedArgs);
1232  return success();
1233  }
1234 };
1235 
1236 } // namespace
1237 
/// Registers generic-op canonicalizations: currently only the removal of
/// identity (pure copy) generics.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1242 
/// Fold hook: delegates to memref::foldMemRefCast, which presumably absorbs
/// compatible memref.cast producers into this op's operands in place; no
/// OpFoldResults are produced.
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
1246 
1247 //===----------------------------------------------------------------------===//
1248 // MapOp
1249 //===----------------------------------------------------------------------===//
1250 
1251 static ParseResult parseDstStyleOp(
1252  OpAsmParser &parser, OperationState &result,
1253  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1254  nullptr) {
1255  // Parse `ins` and `outs`.
1256  SmallVector<Type, 4> inputTypes, outputTypes;
1257  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1258  /*addOperandSegmentSizes=*/false))
1259  return failure();
1260 
1261  // Add result types.
1262  for (Type outputType : outputTypes) {
1263  if (llvm::isa<RankedTensorType>(outputType))
1264  result.addTypes(outputType);
1265  }
1266 
1267  // Parse required attributes.
1268  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1269  return failure();
1270 
1271  // Parse optional attributes.
1272  if (parser.parseOptionalAttrDict(result.attributes))
1273  return failure();
1274  return success();
1275 }
1276 
1277 void MapOp::getAsmBlockArgumentNames(Region &region,
1278  OpAsmSetValueNameFn setNameFn) {
1279  for (Value v : getRegionInputArgs())
1280  setNameFn(v, "in");
1281 }
1282 
1283 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1284  if (!getResults().empty())
1285  setNameFn(getResults().front(), "mapped");
1286 }
1287 
/// Builds a MapOp with a single `init` operand; a tensor init contributes a
/// result type, and `bodyBuild` (when provided) populates the mapper region.
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  // Note: the mapper region only receives block arguments for the inputs;
  // the init element is intentionally not an argument.
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1304 
1305 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1306  const OperationName &payloadOpName,
1307  const NamedAttrList &payloadOpAttrs,
1308  ArrayRef<Value> operands,
1309  bool initFirst = false) {
1310  OpBuilder b(parser.getContext());
1311  Region *body = result.addRegion();
1312  Block &block = body->emplaceBlock();
1313  b.setInsertionPointToStart(&block);
1314  SmallVector<Value> bbArgs;
1315  for (auto &operand : operands) {
1316  block.addArgument(
1317  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1318  b.getUnknownLoc());
1319  }
1320  SmallVector<Value> payloadOpOperands;
1321  // If initFirst flag is enabled, we consider init as the first position of
1322  // payload operands.
1323  if (initFirst) {
1324  payloadOpOperands.push_back(block.getArguments().back());
1325  for (const auto &arg : block.getArguments().drop_back())
1326  payloadOpOperands.push_back(arg);
1327  } else {
1328  payloadOpOperands = {block.getArguments().begin(),
1329  block.getArguments().end()};
1330  }
1331 
1332  Operation *payloadOp = b.create(
1333  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1334  payloadOpOperands,
1335  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1336  .getElementType()},
1337  payloadOpAttrs);
1338  b.create<YieldOp>(result.location, payloadOp->getResults());
1339 }
1340 
1341 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1342  std::optional<OperationName> payloadOpName;
1343  NamedAttrList payloadOpAttrs;
1344  if (succeeded(parser.parseOptionalLBrace())) {
1345  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1346  if (failed(operationName))
1347  return failure();
1348  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1349  return failure();
1350  payloadOpName = operationName.value();
1351  if (parser.parseRBrace())
1352  return failure();
1353  }
1354 
1355  if (parseDstStyleOp(parser, result))
1356  return failure();
1357 
1358  if (payloadOpName.has_value()) {
1359  if (!result.operands.empty())
1360  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1361  payloadOpAttrs,
1362  ArrayRef(result.operands).drop_back());
1363  else
1364  result.addRegion();
1365  } else {
1367  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1368  /*allowType=*/true, /*allowAttrs=*/true)) {
1369  return failure();
1370  }
1371  Region *body = result.addRegion();
1372  if (parser.parseRegion(*body, regionArgs))
1373  return failure();
1374  }
1375  return success();
1376 }
1377 
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it gets the same amount of arguments as the body does.
// If initFirst flag is enabled, we check that init takes the first position in
// operands of payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  // The body must be exactly: payload op + yield terminator.
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // check init
    // NOTE(review): together with the rotated zip below, this accepts
    // payload operands of the form [bbArg1..bbArgN-1, bbArg0], which equals
    // the "init first" layout [init, inputs...] only when the payload takes
    // exactly two operands (the common binary reduction) — confirm intent
    // for payloads with more operands.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // check rest
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    // Payload operands must be exactly the block arguments, in order.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
1410 
1411 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1412  SmallVector<StringRef> elidedAttrs;
1413  std::string attrToElide;
1414  p << " { " << payloadOp->getName().getStringRef();
1415  for (const auto &attr : payloadOp->getAttrs()) {
1416  auto fastAttr =
1417  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1418  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1419  attrToElide = attr.getName().str();
1420  elidedAttrs.push_back(attrToElide);
1421  break;
1422  }
1423  }
1424  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1425  p << " }";
1426 }
1427 
1428 void MapOp::print(OpAsmPrinter &p) {
1429  Block *mapper = getBody();
1430  Operation *payloadOp = findPayloadOp(mapper);
1431  if (payloadOp) {
1432  printShortForm(p, payloadOp);
1433  }
1434 
1435  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1436  p.printOptionalAttrDict((*this)->getAttrs());
1437 
1438  if (!payloadOp) {
1439  // Print region if the payload op was not detected.
1440  p.increaseIndent();
1441  p.printNewline();
1442  p << "(";
1443  llvm::interleaveComma(mapper->getArguments(), p,
1444  [&](auto arg) { p.printRegionArgument(arg); });
1445  p << ") ";
1446 
1447  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1448  p.decreaseIndent();
1449  }
1450 }
1451 
/// Verifies MapOp invariants: mapper arity matches the inputs, block argument
/// types match input element types, and all input shapes match the init
/// shape.
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1486 
1487 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1488  int64_t rank = getInit().getType().getRank();
1489  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1490 }
1491 
1492 ArrayAttr MapOp::getIndexingMaps() {
1493  Builder builder(getContext());
1494  int64_t rank = getInit().getType().getRank();
1495  int64_t numIndexingMaps = getOperands().size();
1497  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1498 }
1499 
1500 void MapOp::getEffects(
1502  &effects) {
1503  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1504 }
1505 
1506 //===----------------------------------------------------------------------===//
1507 // ReduceOp
1508 //===----------------------------------------------------------------------===//
1509 
1510 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1511  OpAsmSetValueNameFn setNameFn) {
1512  for (Value v : getRegionInputArgs())
1513  setNameFn(v, "in");
1514  for (Value v : getRegionOutputArgs())
1515  setNameFn(v, "init");
1516 }
1517 
/// Names the first result "reduced" when the op has tensor results.
void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}
1523 
/// Builds a ReduceOp; tensor inits contribute result types, and `bodyBuild`
/// (when provided) populates the combiner region with arguments for both
/// inputs and inits.
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}
1543 
1544 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1545  int64_t inputRank =
1546  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1547  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1548  utils::IteratorType::parallel);
1549  for (int64_t reductionDim : getDimensions())
1550  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1551  return iteratorTypes;
1552 }
1553 
1554 ArrayAttr ReduceOp::getIndexingMaps() {
1555  int64_t inputRank =
1556  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1557  SmallVector<AffineMap> affineMaps(
1558  getNumDpsInputs(),
1560  AffineMap resultMap =
1562  .dropResults(getDimensions());
1563  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1564  affineMaps.push_back(resultMap);
1565  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1566 }
1567 
1568 void ReduceOp::getEffects(
1570  &effects) {
1571  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1572 }
1573 
1574 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1575  NamedAttrList &attributes,
1576  StringRef attributeName) {
1577  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1578  return failure();
1579 
1580  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1581  return success();
1582 }
1583 
1584 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1585  std::optional<OperationName> payloadOpName;
1586  NamedAttrList payloadOpAttrs;
1587  if (succeeded(parser.parseOptionalLBrace())) {
1588  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1589  if (failed(operationName))
1590  return failure();
1591  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1592  return failure();
1593  payloadOpName = operationName.value();
1594  if (parser.parseRBrace())
1595  return failure();
1596  }
1597 
1598  if (parseDstStyleOp(
1599  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1600  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1601  }))
1602  return failure();
1603 
1604  if (payloadOpName.has_value()) {
1605  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1606  ArrayRef(result.operands), /*initFirst=*/true);
1607  } else {
1609  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1610  /*allowType=*/true, /*allowAttrs=*/true)) {
1611  return failure();
1612  }
1613 
1614  Region *body = result.addRegion();
1615  if (parser.parseRegion(*body, regionArgs))
1616  return failure();
1617  }
1618 
1619  return success();
1620 }
1621 
/// Prints ` attributeName = [v0, v1, ...] ` — the inverse of
/// parseDenseI64ArrayAttr.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1626 
/// Prints a ReduceOp, preferring the compact `{ payload_op }` form when the
/// combiner is trivially a single init-first payload op.
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  // `dimensions` was already printed above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1650 
/// Verifies ReduceOp invariants: uniform input/output shapes, in-range
/// reduction dimensions, init shape equal to the input shape minus the
/// reduced dimensions, and combiner block arguments matching the operand
/// element types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs must share one shape; likewise all outputs.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Each reduced dimension must be a valid input dimension; duplicates
  // collapse in the set.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1736 
1737 //===----------------------------------------------------------------------===//
1738 // TransposeOp
1739 //===----------------------------------------------------------------------===//
1740 
1741 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1742  Region &region, ValueRange inputs,
1743  ValueRange outputs) {
1744  buildGenericRegion(builder, loc, region, inputs, outputs,
1745  [](OpBuilder &b, Location loc, ValueRange args) {
1746  if (!args.empty())
1747  b.create<linalg::YieldOp>(loc, args[0]);
1748  });
1749 }
1750 
1751 void TransposeOp::build(::mlir::OpBuilder &builder,
1752  ::mlir::OperationState &result, Value input, Value init,
1753  DenseI64ArrayAttr permutation,
1754  ArrayRef<NamedAttribute> attributes) {
1755  result.addOperands(input);
1756  result.addOperands(init);
1757  result.addAttribute(getPermutationAttrName(result.name), permutation);
1758  result.addAttributes(attributes);
1759 
1760  // Add output types for `RankedTensorType` output arguments.
1761  Type initType = init.getType();
1762  if (llvm::isa<RankedTensorType>(initType))
1763  result.addTypes(initType);
1764 
1765  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1766  init);
1767 }
1768 
1769 void TransposeOp::build(::mlir::OpBuilder &builder,
1770  ::mlir::OperationState &result, Value input, Value init,
1771  ArrayRef<int64_t> permutation,
1772  ArrayRef<NamedAttribute> attributes) {
1773  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1774  attributes);
1775 }
1776 
1777 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1778  if (failed(parseDstStyleOp(
1779  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1780  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1781  })))
1782  return failure();
1783 
1784  OpBuilder builder(parser.getContext());
1785  buildIdentityRegion(builder, result.location, *result.addRegion(),
1786  /*inputs=*/result.operands,
1787  /*outputs=*/{});
1788  return success();
1789 }
1790 
1791 void TransposeOp::getAsmResultNames(
1792  function_ref<void(Value, StringRef)> setNameFn) {
1793  if (!getResults().empty())
1794  setNameFn(getResults().front(), "transposed");
1795 }
1796 
1798  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1799  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1800  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1801 }
1802 
1803 LogicalResult TransposeOp::verify() {
1804  ArrayRef<int64_t> permutationRef = getPermutation();
1805 
1806  if (!isPermutationVector(permutationRef))
1807  return emitOpError("permutation is not valid");
1808 
1809  auto inputType = getInput().getType();
1810  auto initType = getInit().getType();
1811 
1812  int64_t rank = inputType.getRank();
1813 
1814  if (rank != initType.getRank())
1815  return emitOpError() << "input rank " << rank
1816  << " does not match init rank " << initType.getRank();
1817 
1818  if (rank != static_cast<int64_t>(permutationRef.size()))
1819  return emitOpError() << "size of permutation " << permutationRef.size()
1820  << " does not match the argument rank " << rank;
1821 
1822  auto inputDims = inputType.getShape();
1823  auto initDims = initType.getShape();
1824 
1825  for (int64_t i = 0; i < rank; ++i) {
1826  int64_t inputDim = inputDims[permutationRef[i]];
1827  int64_t initDim = initDims[i];
1828 
1829  if (inputDim != initDim) {
1830  return emitOpError() << "dim(result, " << i << ") = " << initDim
1831  << " doesn't match dim(input, permutation[" << i
1832  << "]) = " << inputDim;
1833  }
1834  }
1835 
1836  return success();
1837 }
1838 
1839 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1840  int64_t rank = getInit().getType().getRank();
1841  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1842 }
1843 
1844 ArrayAttr TransposeOp::getIndexingMaps() {
1845  Builder builder(getContext());
1846  int64_t rank = getInit().getType().getRank();
1847  return builder.getAffineMapArrayAttr(
1849  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1850  builder.getMultiDimIdentityMap(rank)});
1851 }
1852 
1853 void TransposeOp::getEffects(
1855  &effects) {
1856  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1857 }
1858 
1859 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1861  // Single dimension transpose.
1862  if (getPermutation().size() == 0) {
1863  result.push_back(getInput());
1864  return success();
1865  }
1866  // Identity permutation.
1867  if (isIdentityPermutation(getPermutation())) {
1868  result.push_back(getInput());
1869  return success();
1870  }
1871 
1872  return failure();
1873 }
1874 
1875 /// Fold transpose with transpose.
1876 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1878 
1879  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1880  PatternRewriter &rewriter) const override {
1881  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1882  if (!defTransposeOp)
1883  return failure();
1884  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1885  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1886  SmallVector<int64_t> foldedPerms;
1887  foldedPerms.reserve(perms.size());
1888  for (int64_t perm : perms)
1889  foldedPerms.push_back(defPerms[perm]);
1890 
1891  rewriter.replaceOpWithNewOp<TransposeOp>(
1892  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1893  foldedPerms);
1894  return success();
1895  }
1896 };
1897 
1898 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1899  MLIRContext *context) {
1900  results.add<FoldTransposeWithTranspose>(context);
1901 }
1902 
1903 //===----------------------------------------------------------------------===//
1904 // BroadcastOp
1905 //===----------------------------------------------------------------------===//
1906 
1907 void BroadcastOp::build(::mlir::OpBuilder &builder,
1908  ::mlir::OperationState &result, Value input, Value init,
1909  DenseI64ArrayAttr dimensions,
1910  ArrayRef<NamedAttribute> attributes) {
1911  result.addOperands(input);
1912  result.addOperands(init);
1913  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
1914  result.addAttributes(attributes);
1915 
1916  // Add output types for `RankedTensorType` output arguments.
1917  Type initType = init.getType();
1918  if (llvm::isa<RankedTensorType>(initType))
1919  result.addTypes(initType);
1920 
1921  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1922  init);
1923 }
1924 
1925 void BroadcastOp::build(::mlir::OpBuilder &builder,
1926  ::mlir::OperationState &result, Value input, Value init,
1927  ArrayRef<int64_t> dimensions,
1928  ArrayRef<NamedAttribute> attributes) {
1929  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
1930  attributes);
1931 }
1932 
1933 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
1934  if (failed(parseDstStyleOp(
1935  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1936  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1937  })))
1938  return failure();
1939 
1940  OpBuilder builder(parser.getContext());
1941  buildIdentityRegion(builder, result.location, *result.addRegion(),
1942  /*inputs=*/result.operands,
1943  /*outputs=*/{});
1944  return success();
1945 }
1946 
1947 void BroadcastOp::getAsmResultNames(
1948  function_ref<void(Value, StringRef)> setNameFn) {
1949  if (!getResults().empty())
1950  setNameFn(getResults().front(), "broadcasted");
1951 }
1952 
1954  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1955  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1956  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1957 }
1958 
1959 LogicalResult BroadcastOp::verify() {
1960  ArrayRef<int64_t> dimensionsRef = getDimensions();
1961 
1962  auto inputType = getInput().getType();
1963  auto initType = getInit().getType();
1964 
1965  int64_t inputRank = inputType.getRank();
1966  int64_t initRank = initType.getRank();
1967 
1968  auto inputShape = inputType.getShape();
1969  auto initShape = initType.getShape();
1970 
1971  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
1972  return emitOpError() << "input rank plus added dimensions does not "
1973  "match init rank. input rank: "
1974  << inputRank
1975  << ", dimensions size: " << dimensionsRef.size()
1976  << ", init rank: " << initRank;
1977 
1978  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
1979  if (dim < 0 || dim >= initRank)
1980  return emitOpError() << "dimension " << idx
1981  << " is out of range. expected range: [0, "
1982  << initRank - 1 << "], got: " << dim;
1983  }
1984 
1985  // Mapping from input dims to init dims.
1986  SmallVector<int64_t> dimMap;
1987  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
1988  if (!llvm::is_contained(dimensionsRef, dim))
1989  dimMap.push_back(dim);
1990  }
1991 
1992  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
1993  // This dimensions is mapped from the input. Init and input dims should
1994  // match.
1995  if (inputShape[inputDimIdx] != initShape[initDimIdx])
1996  return emitOpError() << "input dim " << inputDimIdx
1997  << " should match init dim " << initDimIdx
1998  << ". input: " << inputShape[inputDimIdx]
1999  << ", init: " << initShape[initDimIdx];
2000  }
2001 
2002  return success();
2003 }
2004 
2005 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2006  int64_t rank = getInit().getType().getRank();
2007  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2008 }
2009 
2010 ArrayAttr BroadcastOp::getIndexingMaps() {
2011  Builder builder(getContext());
2012  int64_t rank = getInit().getType().getRank();
2013  return builder.getAffineMapArrayAttr(
2014  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2015  builder.getMultiDimIdentityMap(rank)});
2016 }
2017 
2018 void BroadcastOp::getEffects(
2020  &effects) {
2021  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2022 }
2023 
2024 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2025  MLIRContext *context) {
2026  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2027 }
2028 
2029 //===----------------------------------------------------------------------===//
2030 // YieldOp
2031 //===----------------------------------------------------------------------===//
2032 
2034  if (getNumOperands() > 0)
2035  p << ' ' << getOperands();
2036  p.printOptionalAttrDict((*this)->getAttrs());
2037  if (getNumOperands() > 0)
2038  p << " : " << getOperandTypes();
2039 }
2040 
2041 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2043  SmallVector<Type, 2> types;
2044  SMLoc loc = parser.getCurrentLocation();
2045  return failure(parser.parseOperandList(opInfo) ||
2046  parser.parseOptionalAttrDict(result.attributes) ||
2047  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2048  parser.resolveOperands(opInfo, types, loc, result.operands));
2049 }
2050 
2051 // Check the operand number and types must match the element types of the
2052 // LinalgOp interface's shaped operands.
2053 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2054  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2055  return op.emitOpError("expected number of yield values (")
2056  << op.getNumOperands()
2057  << ") to match the number of inits / outs operands of the enclosing "
2058  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2059 
2060  for (OpOperand &opOperand : op->getOpOperands()) {
2061  OpOperand *outputOperand =
2062  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2063  Type elementType = outputOperand->get().getType();
2064  if (isa<MemRefType, RankedTensorType>(elementType))
2065  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2066  if (opOperand.get().getType() != elementType)
2067  return op.emitOpError("type of yield operand ")
2068  << (opOperand.getOperandNumber() + 1) << " ("
2069  << opOperand.get().getType() << ") doesn't match "
2070  << "the element type of the enclosing linalg.generic op ("
2071  << elementType << ")";
2072  }
2073  return success();
2074 }
2075 
2076 LogicalResult linalg::YieldOp::verify() {
2077  auto *parentOp = (*this)->getParentOp();
2078  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2079  return emitOpError("expected single non-empty parent region");
2080 
2081  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2082  return verifyYield(*this, linalgOp);
2083 
2084  return emitOpError("expected parent op with LinalgOp interface");
2085 }
2086 
2087 //===----------------------------------------------------------------------===//
2088 // IndexOp
2089 //===----------------------------------------------------------------------===//
2090 
2091 LogicalResult IndexOp::verify() {
2092  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2093  if (!linalgOp)
2094  return emitOpError("expected parent op with LinalgOp interface");
2095  if (linalgOp.getNumLoops() <= getDim())
2096  return emitOpError("expected dim (")
2097  << getDim() << ") to be lower than the number of loops ("
2098  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2099  return success();
2100 }
2101 
2102 /////// Operations corresponding to library calls defined with Tablegen ////////
2103 
2104 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2105 
2106 #define GET_OP_CLASSES
2107 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2108 
2109 #define GET_OP_CLASSES
2110 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2111 
2112 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2113  unsigned rank,
2114  MLIRContext *context) {
2115  if (maybeMap)
2116  return *maybeMap;
2117  if (rank == 0)
2118  return AffineMap::get(context);
2119  return AffineMap::getMultiDimIdentityMap(rank, context);
2120 }
2121 
2123 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2124  MLIRContext *context) {
2126  res.reserve(num);
2127  for (unsigned i = 0; i < num; ++i)
2128  res.push_back(getAffineDimExpr(startIdx++, context));
2129  return res;
2130 }
2131 
2134  auto rangeA = llvm::make_range(a.begin(), a.end());
2135  auto rangeB = llvm::make_range(b.begin(), b.end());
2136  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2137  return llvm::to_vector<4>(concatRanges);
2138 }
2139 
2140 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2141  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2142  ss << "view";
2143  for (auto size : memref.getShape())
2144  if (size < 0)
2145  ss << "sx";
2146  else
2147  ss << size << "x";
2148  if (failed(appendMangledType(ss, memref.getElementType())))
2149  return failure();
2150  if (auto as = memref.getMemorySpace()) {
2151  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2152  ss << "as" << attr.getInt();
2153  else
2154  return failure();
2155  }
2156  return success();
2157  }
2158  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2159  ss << "vector";
2160  llvm::interleave(
2161  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2162  if (failed(appendMangledType(ss, vec.getElementType())))
2163  return failure();
2164  return success();
2165  }
2166  if (t.isSignlessIntOrIndexOrFloat()) {
2167  ss << t;
2168  return success();
2169  }
2170  return failure();
2171 }
2172 
2174  assert(isa<LinalgOp>(op));
2175  std::string name(op->getName().getStringRef().str());
2176  std::string fun = "";
2177  for (NamedAttribute kv : op->getAttrs()) {
2178  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2179  fun = stringifyEnum(ufa.getValue()).str() + "_";
2180  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2181  fun = stringifyEnum(bfa.getValue()).str() + "_";
2182  }
2183  }
2184  name.reserve(128);
2185  std::replace(name.begin(), name.end(), '.', '_');
2186  llvm::raw_string_ostream ss(name);
2187  ss << "_" << fun;
2188  for (Type t : op->getOperandTypes()) {
2189  if (failed(appendMangledType(ss, t)))
2190  return std::string();
2191  ss << "_";
2192  }
2193  std::string res = ss.str();
2194  res.pop_back();
2195  return res;
2196 }
2197 
2198 //===----------------------------------------------------------------------===//
2199 // Canonicalizers and Folders.
2200 //===----------------------------------------------------------------------===//
2201 
2202 namespace {
2203 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2205 
2206  LogicalResult matchAndRewrite(LinalgOp op,
2207  PatternRewriter &rewriter) const override {
2208  for (OpOperand &opOperand : op->getOpOperands()) {
2209  // Linalg "inputs" may be either tensor or memref type.
2210  // tensor<0xelt_type> is a convention that may not always mean
2211  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2212  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2213  if (!mt)
2214  continue;
2215  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2216  rewriter.eraseOp(op);
2217  return success();
2218  }
2219  }
2220  return failure();
2221  }
2222 };
2223 
2224 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2225 /// result that is more static than the linalg op.
2226 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2228 
2229  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2230  PatternRewriter &rewriter) const override {
2231  if (!tensor::canFoldIntoProducerOp(castOp))
2232  return failure();
2233 
2234  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2235  if (!linalgOp)
2236  return failure();
2237 
2238  // Cast can be in conditionally reachable region, if which case folding will
2239  // generate invalid code. Only conservatively fold ops in same block for
2240  // now.
2241  if (castOp->getBlock() != linalgOp->getBlock())
2242  return failure();
2243 
2244  OpBuilder::InsertionGuard guard(rewriter);
2245  rewriter.setInsertionPoint(linalgOp);
2246 
2247  Location loc = linalgOp.getLoc();
2248  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2249  unsigned resultNumber = resultValue.getResultNumber();
2250  auto resultType =
2251  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2252  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2253  // going from a more dynamic shape to a less dynamic shape. If the producer
2254  // for this cast, i.e. producer of the out operand, is also an operation
2255  // that folds with tensor.cast consumer (like this pattern), the cast will
2256  // continue to propagate as far up the stack as it can go.
2257  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2258  Value newOperand =
2259  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2260  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2261  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2262  linalgOp.getDpsInits().end());
2263  outputOperands[resultNumber] = newOperand;
2264  newOperands.append(outputOperands.begin(), outputOperands.end());
2265 
2266  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2267  linalgOp->result_type_end());
2268  resultTypes[resultNumber] = resultType;
2269  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2270 
2271  // Create a tensor.cast operation back to the original type.
2272  Value castBack = rewriter.create<tensor::CastOp>(
2273  loc, resultValue.getType(), newOp->getResult(resultNumber));
2274 
2275  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2276  results[resultNumber] = castBack;
2277  rewriter.replaceOp(linalgOp, results);
2278  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2279  return success();
2280  }
2281 };
2282 
2283 /// For each of the operand in `operands` this function maps the static sizes of
2284 /// dimensions to their affine dim expressions.
2285 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2286  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2287  for (OpOperand &opOperand : operands) {
2288  if (linalgOp.isScalar(&opOperand))
2289  continue;
2290  Value src = opOperand.get();
2291  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2292  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2293 
2294  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2295  // `tensor.cast` operation and source of the cast operation has a static
2296  // shape, then assign it to the `sourceShape`.
2297  auto *parentOp = src.getDefiningOp();
2298  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2299  if (parentOp) {
2300  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2301  Value castSource = castOp.getSource();
2302  auto castSourceType =
2303  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2304  if (castSourceType && castSourceType.hasStaticShape())
2305  sourceShape = castSourceType.getShape();
2306  }
2307  }
2308 
2309  // If the source shape's dimension has a static shape, map the affine dim
2310  // expression to the known static size.
2311  for (unsigned i = 0; i < sourceShape.size(); i++) {
2312  if (sourceType.isDynamicDim(i))
2313  continue;
2314  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2315  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2316  }
2317  }
2318 }
2319 
/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types is stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and same operand is added in the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // The operand is appended unconditionally; it is replaced in place below
  // (by operand index) if a cast to a more static type is needed.
  newOperands.push_back(src);
  // Scalar operands carry no shape to refine.
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // Already fully-static init operands contribute their type unchanged.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    // Keep the dimension as-is if it is already static or no static size is
    // known for its dim expression.
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  // Init operands also determine the result type of the new op.
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2371 
2372 /// Static shapes for the operands can be inferred if any one of the operands
2373 /// have a static shape. This can be done by referring to the affine dim
2374 /// expressions for the operand.
2375 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2377 
2378  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2379  PatternRewriter &rewriter) const override {
2380  if (!linalgOp.hasPureTensorSemantics())
2381  return failure();
2382 
2383  // Maps must be projected permutations.
2384  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2385  return !map.isProjectedPermutation();
2386  }))
2387  return failure();
2388 
2389  // Maps affine dim expressions to the static size of that dimension.
2390  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2391  Location loc = linalgOp.getLoc();
2392 
2393  // For each of the affine dim expression, check if the size is known. If
2394  // known add that in the map.
2395  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2396 
2397  SmallVector<Value> newOperands;
2398  SmallVector<Type> resultTypes;
2399 
2400  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2401  // change in their types.
2402  bool changeNeeded = false;
2403  newOperands.reserve(linalgOp->getNumOperands());
2404  resultTypes.reserve(linalgOp.getNumDpsInits());
2405 
2406  // Iterate over all the operands and update the static sizes.
2407  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2408  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2409  affineExprToSize, linalgOp, newOperands,
2410  resultTypes, changeNeeded);
2411  }
2412 
2413  // If the generic op has all the required static information, no
2414  // canonicalization needed.
2415  if (!changeNeeded)
2416  return failure();
2417 
2418  // Clone op.
2419  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2420  SmallVector<Value> replacements;
2421  replacements.reserve(newOp->getNumResults());
2422  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2423  Value newResult = std::get<1>(it);
2424  Value oldResult = std::get<0>(it);
2425  Type newType = newResult.getType();
2426  Type oldType = oldResult.getType();
2427  replacements.push_back(
2428  (newType != oldType)
2429  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2430  : newResult);
2431  }
2432  rewriter.replaceOp(linalgOp, replacements);
2433  return success();
2434  }
2435 };
2436 
2437 } // namespace
2438 
2439 // All named ops canonicalizers and folders are auto-generated in the
2440 // .cpp.inc.
2441 
2442 //===----------------------------------------------------------------------===//
2443 // SoftmaxOp
2444 //===----------------------------------------------------------------------===//
2445 
2446 LogicalResult SoftmaxOp::verify() {
2447  ShapedType inputType = getInputOperandType();
2448  ShapedType outputType = getOutputOperandType();
2449 
2450  ArrayRef<int64_t> inputShape = inputType.getShape();
2451  ArrayRef<int64_t> outputShape = outputType.getShape();
2452  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2453  return emitOpError("incompatible output shape");
2454 
2455  int64_t inputRank = getInputOperandRank();
2456  int64_t dimension = getDimension();
2457  if ((dimension < 0) || (dimension >= inputRank))
2458  return emitOpError("incorrect dimension specified");
2459 
2460  return success();
2461 }
2462 
2463 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2464  int64_t operandRank = getInputOperandRank();
2465  SmallVector<Range> loopBounds(operandRank);
2466  Location loc = getLoc();
2467  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2468  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2469  Value source = getInput();
2470  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2471  loopBounds[dim].offset = zero;
2472  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2473  loopBounds[dim].stride = one;
2474  }
2475  return loopBounds;
2476 }
2477 
2478 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2479  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2480  utils::IteratorType::parallel);
2481  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2482  return iteratorTypes;
2483 }
2484 
2485 FailureOr<TilingResult>
2486 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2487  ArrayRef<OpFoldResult> offsets,
2488  ArrayRef<OpFoldResult> sizes) {
2489  int64_t rank = getInputOperandRank();
2490  auto oneAttr = builder.getI64IntegerAttr(1);
2491  SmallVector<OpFoldResult> strides(rank, oneAttr);
2492  SmallVector<Value> tiledOperands;
2493  tiledOperands.emplace_back(
2494  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2495  tiledOperands.emplace_back(
2496  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2497 
2498  SmallVector<Type, 4> resultTypes;
2499  if (hasPureTensorSemantics())
2500  resultTypes.push_back(tiledOperands[1].getType());
2501  Operation *tiledOp =
2502  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2503 
2504  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2505 }
2506 
2507 LogicalResult SoftmaxOp::getResultTilePosition(
2508  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2509  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2510  SmallVector<OpFoldResult> &resultSizes) {
2511  if (resultNumber == 0) {
2512  resultOffsets.assign(offsets.begin(), offsets.end());
2513  resultSizes.assign(sizes.begin(), sizes.end());
2514  return success();
2515  }
2516  return failure();
2517 }
2518 
// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  // Absorb `memref.cast` operands that only refine static shape information.
  return memref::foldMemRefCast(*this);
}
2523 
2524 LogicalResult
2526  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2528  Location loc = getOperation()->getLoc();
2529  IRRewriter rewriter(b);
2530  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2531  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2532  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2533  if (!outputShapedType.isDynamicDim(dim)) {
2534  // Static dim: Return IntegerAttr.
2535  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2536  } else {
2537  // Dynamic dim: Return Value.
2538  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2539  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2540  }
2541  }
2542  reifiedReturnShapes.emplace_back(std::move(shapes));
2543  return success();
2544 }
2545 
2546 void SoftmaxOp::getEffects(
2548  &effects) {
2549  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2550  if (!llvm::isa<MemRefType>(operand.getType()))
2551  continue;
2552  effects.emplace_back(MemoryEffects::Read::get(),
2553  &getOperation()->getOpOperand(index), /*stage=*/0,
2554  /*effectOnFullRegion=*/true,
2556  }
2557 
2558  for (OpOperand &operand : getDpsInitsMutable()) {
2559  if (!llvm::isa<MemRefType>(operand.get().getType()))
2560  continue;
2561  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2562  /*effectOnFullRegion=*/true,
2564  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2565  /*effectOnFullRegion=*/true,
2567  }
2568 }
2569 
2570 // Helper functions for softmax decomposition.
2571 // @{
2572 
2573 // Helper function to produce the iterator types (reduction or parallel) and
2574 // affine maps for the iterators used in the decomposition of softmax.
2575 // This method creates:
2576 // If allParallel == true:
2577 // - iterator type: {parallel, ..., parallel}
2578 // - affine maps:
2579 // -- identity with inputRank dimensions.
2580 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2581 // where N == inputRank.
2582 //
2583 // If allParallel == false:
2584 // - iterator type at dim(i) == parallel for i != \p dim and
2585 // dim(dim) == reduction.
2586 // - affine map:
2587 // -- identity with inputRank dimensions.
2588 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2589 // where N == inputRank.
2590 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2592  int64_t dim, bool allParallel = false) {
2593  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2594  utils::IteratorType::parallel);
2595  if (!allParallel)
2596  iteratorTypes[dim] = utils::IteratorType::reduction;
2597  MLIRContext *ctxt = builder.getContext();
2598  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2599  SmallVector<AffineExpr, 2> affineExprs;
2600  for (int i = 0; i < inputRank; i++) {
2601  if (i != dim)
2602  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2603  }
2604  auto reductionMap =
2605  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2606  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2607  return std::make_tuple(iteratorTypes, indexingMaps);
2608 }
2609 
2610 // Helper function to produce a linalg.generic that computes a reduction on
2611 // dimension \p dim with the operation type \p T.
2612 template <typename T>
2613 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2614  int64_t dim) {
2615  auto inputType = cast<ShapedType>(input.getType());
2616  ArrayRef<int64_t> inputShape = inputType.getShape();
2617  int64_t inputRank = inputShape.size();
2618  auto [iteratorTypes, indexingMaps] =
2619  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2620  assert(indexingMaps.size() == 2 &&
2621  "We should have two maps: 1 for the input, 1 for the output");
2622  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2623 
2624  auto genericOp = builder.create<linalg::GenericOp>(
2625  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2626  [&](OpBuilder &b, Location loc, ValueRange args) {
2627  Value result = b.create<T>(loc, args[0], args[1]);
2628  b.create<linalg::YieldOp>(loc, result);
2629  });
2630  return genericOp.getResult(0);
2631 }
2632 
2633 /// Produce a linalg generic that computes the second step of the softmax
2634 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2635 /// on dimension \p dim.
2636 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2637  Value max, Value output, int64_t dim) {
2638  auto inputType = cast<ShapedType>(input.getType());
2639  ArrayRef<int64_t> inputShape = inputType.getShape();
2640  int64_t inputRank = inputShape.size();
2641  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2642  builder, inputRank, dim, /*allParallel=*/true);
2643  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2644  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2645  // Add the affine map for the output argument.
2646  indexingMaps.push_back(indexingMaps[0]);
2647  auto genericOp = builder.create<linalg::GenericOp>(
2648  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2649  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2650  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2651  Value result = b.create<math::ExpOp>(loc, diff);
2652  b.create<linalg::YieldOp>(loc, result);
2653  });
2654  return genericOp.getResult(0);
2655 }
2656 
2657 /// Produce a linalg generic that computes the final step of the softmax
2658 /// decomposition.
2659 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2660 /// yield n / d
2661 /// }
2662 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2663  Value denominator, Value output, int64_t dim) {
2664  auto inputType = cast<ShapedType>(numerator.getType());
2665  ArrayRef<int64_t> inputShape = inputType.getShape();
2666  int64_t inputRank = inputShape.size();
2667  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2668  builder, inputRank, dim, /*allParallel=*/true);
2669  assert(indexingMaps.size() == 2 &&
2670  "We should have one map for each input (2)");
2671  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2672  // Add the affine map for the output tensor.
2673  indexingMaps.push_back(indexingMaps[0]);
2674  auto genericOp = builder.create<linalg::GenericOp>(
2675  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2676  indexingMaps, iteratorTypes,
2677  [&](OpBuilder &b, Location loc, ValueRange args) {
2678  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2679  b.create<linalg::YieldOp>(loc, result);
2680  });
2681  return genericOp.getResult(0);
2682 }
2683 // @} End helper functions for softmax decomposition.
2684 
2685 /// Given an N-dimensional tensor x, this method converts
2686 /// softmax(x) to the following sequence of operations:
2687 ///
2688 /// 1. Compute the max of x along dimension d. This results
2689 /// in a N-1 dimensional tensor m.
2690 /// m = max(x, dim = d)
2691 ///
2692 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2693 /// a N dimensional tensor z.
2694 /// z = exp(x - m)
2695 ///
2696 /// 3. Compute the sum of z along dimension d. This results in
2697 /// a N-1 dimensional tensor l.
2698 /// l = sum(z, dim = d)
2699 ///
2700 /// 4. Divide z and l. This gives the N-dimensional softmax.
2701 /// softmax = z / l
2702 ///
2703 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2704  OpBuilder::InsertionGuard guard(b);
2705  b.setInsertionPoint(*this);
2706  Location loc = getLoc();
2707  Value input = getInput();
2708  ShapedType inputType = getInputOperandType();
2709  Type elementType = inputType.getElementType();
2710  int64_t reductionDim = getDimension();
2711  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2712  Value output = getOutput();
2713  dims.erase(dims.begin() + reductionDim);
2714  // Step 1: Compute max along dim.
2715  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2716  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2717  elementType, b, loc,
2718  /*useOnlyFiniteValue=*/true);
2719  Value neutralForMaxFInit =
2720  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2721  .result();
2722  Value max =
2723  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2724 
2725  // Step 2: Subtract max from input and exponentiate.
2726  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2727 
2728  // Step 3: Compute sum along dim.
2729  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2730  b, loc, /*useOnlyFiniteValue=*/true);
2731  Value zeroInit =
2732  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2733  Value denominator =
2734  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2735 
2736  // Step 4: Compute softmax.
2737  Value result =
2738  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2739  return SmallVector<Value>{result};
2740 }
2741 
2742 //===----------------------------------------------------------------------===//
2743 // WinogradFilterTransformOp
2744 //===----------------------------------------------------------------------===//
2745 
2746 LogicalResult WinogradFilterTransformOp::verify() {
2747  auto filterType = cast<ShapedType>(getFilter().getType());
2748  ArrayRef<int64_t> filterShape = filterType.getShape();
2749  int64_t filterH = filterShape[1];
2750  int64_t filterW = filterShape[2];
2751  int64_t r = getR();
2752  int64_t m = getM();
2753 
2754  if (filterH != r && filterH != 1)
2755  return emitOpError("expect filter height either equals to r or 1");
2756  if (filterW != r && filterW != 1)
2757  return emitOpError("expect filter width either equals to r or 1");
2758  if (filterH == 1 && filterW == 1)
2759  return emitOpError("expect either filter height or width equals to r");
2760 
2761  SmallVector<int64_t> expectedOutputShape;
2762  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2763  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2764  expectedOutputShape.push_back(filterShape[3]);
2765  expectedOutputShape.push_back(filterShape[0]);
2766 
2767  auto outputType = cast<ShapedType>(getOutput().getType());
2768  ArrayRef<int64_t> outputShape = outputType.getShape();
2769  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2770  return emitOpError("the output shape is not expected");
2771  }
2772  return success();
2773 }
2774 
2775 //===----------------------------------------------------------------------===//
2776 // WinogradInputTransformOp
2777 //===----------------------------------------------------------------------===//
2778 
2779 LogicalResult WinogradInputTransformOp::verify() {
2780  auto inputType = cast<ShapedType>(getInput().getType());
2781  ArrayRef<int64_t> inputShape = inputType.getShape();
2782  int64_t inputH = inputShape[1];
2783  int64_t inputW = inputShape[2];
2784  int m = getM();
2785  int r = getR();
2786  int64_t tileSize = m + r - 1;
2787  bool leftTransform = inputH != 1;
2788  bool rightTransform = inputW != 1;
2789 
2790  SmallVector<int64_t> expectedOutputShape(6, inputH);
2791  if (ShapedType::isDynamic(inputH)) {
2792  expectedOutputShape[0] = tileSize;
2793  expectedOutputShape[2] = ShapedType::kDynamic;
2794  } else {
2795  expectedOutputShape[0] = leftTransform ? tileSize : 1;
2796  expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
2797  }
2798  if (ShapedType::isDynamic(inputW)) {
2799  expectedOutputShape[1] = tileSize;
2800  expectedOutputShape[3] = ShapedType::kDynamic;
2801  } else {
2802  expectedOutputShape[1] = rightTransform ? tileSize : 1;
2803  expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
2804  }
2805  expectedOutputShape[4] = inputShape[0];
2806  expectedOutputShape[5] = inputShape[3];
2807 
2808  auto outputType = cast<ShapedType>(getOutput().getType());
2809  ArrayRef<int64_t> outputShape = outputType.getShape();
2810  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2811  return emitOpError("the output shape is not expected");
2812  }
2813  return success();
2814 }
2815 
2816 //===----------------------------------------------------------------------===//
2817 // WinogradOutputTransformOp
2818 //===----------------------------------------------------------------------===//
2819 
2820 LogicalResult WinogradOutputTransformOp::verify() {
2821  auto valueType = cast<ShapedType>(getValue().getType());
2822  ArrayRef<int64_t> valueShape = valueType.getShape();
2823  int64_t valueH = valueShape[0];
2824  int64_t valueW = valueShape[1];
2825  int64_t valueTileH = valueShape[2];
2826  int64_t valueTileW = valueShape[3];
2827  int m = getM();
2828  int r = getR();
2829  bool leftTransform = valueH != 1;
2830  bool rightTransform = valueW != 1;
2831 
2832  SmallVector<int64_t> expectedOutputShape(4, valueH);
2833  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
2834  expectedOutputShape[1] = ShapedType::kDynamic;
2835  } else {
2836  if (valueH != (leftTransform ? m + r - 1 : 1))
2837  return emitOpError("expect input height equals to input tile size");
2838  expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
2839  }
2840  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
2841  expectedOutputShape[2] = ShapedType::kDynamic;
2842  } else {
2843  if (valueW != (rightTransform ? m + r - 1 : 1))
2844  return emitOpError("expect input width equals to input tile size");
2845  expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
2846  }
2847  expectedOutputShape[0] = valueShape[4];
2848  expectedOutputShape[3] = valueShape[5];
2849 
2850  auto outputType = cast<ShapedType>(getOutput().getType());
2851  ArrayRef<int64_t> outputShape = outputType.getShape();
2852  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2853  return emitOpError("the output shape is not expected");
2854  }
2855  return success();
2856 }
2857 
//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

/// Register dialect-wide canonicalization patterns: erasing dead linalg ops,
/// folding tensor.cast consumers, and inferring static shapes of operands.
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}
2867 
2869  Attribute value, Type type,
2870  Location loc) {
2871  return arith::ConstantOp::materialize(builder, value, type, loc);
2872 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1741
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:271
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2591
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2662
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:119
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2140
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:259
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:330
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1574
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1622
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2636
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1382
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:155
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1251
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2613
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1122
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2053
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:297
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:897
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:290
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:51
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1305
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1411
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:323
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:185
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:291
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:321
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:398
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:251
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:183
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:269
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:107
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2585
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2132
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:98
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2173
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2112
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2123
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:89
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:76
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:78
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:135
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:349
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:755
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:606
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:1876
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1879
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.