//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}
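
// Illustrative note (not part of the original source; names are made up): for
// a value %v of type tensor<4x?xf32>, dim 0 folds to the constant index
// attribute 4, while dim 1 materializes IR such as:
//
//   %c1 = arith.constant 1 : index
//   %d1 = tensor.dim %v, %c1 : tensor<4x?xf32>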

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
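
// Illustrative example (shapes and names made up): on a tensor source,
// getSlice produces something like
//
//   %slice = tensor.extract_slice %src[0, 0] [2, 3] [1, 1]
//       : tensor<4x8xf32> to tensor<2x3xf32>
//
// whereas a memref source yields the analogous memref.subview.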

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
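
// For illustration (not in the original source): the default indexing maps for
// a plain matmul correspond to
//
//   affine_map<(m, n, k) -> (m, k)>   // LHS
//   affine_map<(m, n, k) -> (k, n)>   // RHS
//   affine_map<(m, n, k) -> (m, n)>   // accumulator/output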

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
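
// Illustrative output of the printer above (SSA names and types made up):
//
//   ins(%arg0, %arg1 : tensor<4x8xf32>, tensor<8x16xf32>)
//       outs(%arg2 : tensor<4x16xf32>)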

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
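
// Illustrative example (not from the original source): a self copy such as
//
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
//
// is simply erased, while the tensor-semantics variant is replaced by its
// input value.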

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
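
// Illustrative folding (shapes and names made up): with
//
//   %fill = linalg.fill ins(%cst : f32)
//       outs(%init : tensor<6xf32>) -> tensor<6xf32>
//   %r = tensor.expand_shape %fill [[0, 1]] output_shape [2, 3]
//       : tensor<6xf32> into tensor<2x3xf32>
//
// the pattern reshapes %init instead and fills the expanded tensor directly.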

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
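
// Illustrative folding (shapes made up): padding a fill of %cst with %cst
//
//   %fill = linalg.fill ins(%cst : f32)
//       outs(%init : tensor<2x3xf32>) -> tensor<2x3xf32>
//   %pad = tensor.pad %fill low[1, 1] high[1, 1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<2x3xf32> to tensor<4x5xf32>
//
// becomes a single fill of a 4x5 tensor.empty with %cst.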

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
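
// Illustrative offset arithmetic (numbers made up): if the pad adds a low
// padding of 2 along a dimension and the padded slice was inserted at offset
// 3, the unpadded source is inserted at offset 3 + 2 = 5, with sizes taken
// from the pad's source type; the padded area is already covered by the fill.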

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
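
// Illustrative folding (names made up): extracting any element of a filled
// tensor yields the fill value itself, so
//
//   %fill = linalg.fill ins(%cst : f32)
//       outs(%init : tensor<2x3xf32>) -> tensor<2x3xf32>
//   %v = tensor.extract %fill[%i, %j] : tensor<2x3xf32>
//
// folds %v to %cst.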

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
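
// Illustrative folding (shapes made up): transposing a filled tensor is the
// same as filling the transposed init, so
//
//   %fill = linalg.fill ins(%cst : f32)
//       outs(%a : tensor<2x3xf32>) -> tensor<2x3xf32>
//   %t = linalg.transpose ins(%fill : tensor<2x3xf32>)
//       outs(%b : tensor<3x2xf32>) permutation = [1, 0]
//
// becomes a direct fill of %b.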

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
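
// Illustrative folding (shapes made up): concatenating two fills of the same
// value
//
//   %c = tensor.concat dim(0) %fill_a, %fill_b
//       : (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>
//
// becomes a concat of the two init tensors followed by a single fill.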

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must be checked into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
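
// Illustrative example (names made up): an identity generic such as
//
//   %r = linalg.generic
//       {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
//        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%init : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
//
// is removed and all uses of %r are replaced by %in.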

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (besides the
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check init.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the remaining operands.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}
1512 
1513 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1514  SmallVector<StringRef> elidedAttrs;
1515  std::string attrToElide;
1516  p << " { " << payloadOp->getName().getStringRef();
1517  for (const auto &attr : payloadOp->getAttrs()) {
1518  auto fastAttr =
1519  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1520  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1521  attrToElide = attr.getName().str();
1522  elidedAttrs.push_back(attrToElide);
1523  break;
1524  }
1525  }
1526  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1527  p << " }";
1528 }
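// Note: the elision above only drops a fastmath attribute equal to the
// default `none`. For example (a sketch), a payload with non-default flags
// would still print its attribute dictionary:
//   { arith.addf {fastmath = #arith.fastmath<fast>} }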
1529 
1530 void MapOp::print(OpAsmPrinter &p) {
1531  Block *mapper = getBody();
1532  Operation *payloadOp = findPayloadOp(mapper);
1533  if (payloadOp) {
1534  printShortForm(p, payloadOp);
1535  }
1536 
1537  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1538  p.printOptionalAttrDict((*this)->getAttrs());
1539 
1540  if (!payloadOp) {
1541  // Print region if the payload op was not detected.
1542  p.increaseIndent();
1543  p.printNewline();
1544  p << "(";
1545  llvm::interleaveComma(mapper->getArguments(), p,
1546  [&](auto arg) { p.printRegionArgument(arg); });
1547  p << ") ";
1548 
1549  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1550  p.decreaseIndent();
1551  }
1552 }
1553 
1554 LogicalResult MapOp::verify() {
1555  auto *bodyBlock = getBody();
1556  auto blockArgs = bodyBlock->getArguments();
1557 
1558  // Checks if the number of `inputs` matches the arity of the `mapper` region.
1559  if (getInputs().size() != blockArgs.size())
1560  return emitOpError() << "expects number of operands to match the arity of "
1561  "mapper, but got: "
1562  << getInputs().size() << " and " << blockArgs.size();
1563 
1564  // The parameters of the mapper should all match the element types of the inputs.
1565  for (const auto &[bbArgType, inputArg] :
1566  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1567  auto inputElemType =
1568  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1569  if (bbArgType != inputElemType) {
1570  return emitOpError() << "expected element type of input " << inputElemType
1571  << " to match bbArg type " << bbArgType;
1572  }
1573  }
1574 
1575  // The shape of each input must match the shape of the output.
1576  auto outputShape = getInit().getType().getShape();
1577  for (Type inputArgType : TypeRange{getInputs()}) {
1578  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1579  if (inputElemShape != outputShape) {
1580  return emitOpError() << "expected shape of input (" << inputElemShape
1581  << ") to match shape of output (" << outputShape
1582  << ")";
1583  }
1584  }
1585 
1586  return success();
1587 }
1588 
1589 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1590  int64_t rank = getInit().getType().getRank();
1591  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1592 }
1593 
1594 ArrayAttr MapOp::getIndexingMaps() {
1595  Builder builder(getContext());
1596  int64_t rank = getInit().getType().getRank();
1597  int64_t numIndexingMaps = getOperands().size();
1598  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1599  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1600 }
1601 
1602 void MapOp::getEffects(
1603  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1604  &effects) {
1605  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1606 }
1607 
1608 Speculation::Speculatability MapOp::getSpeculatability() {
1609  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1610 }
1611 
1612 //===----------------------------------------------------------------------===//
1613 // ReduceOp
1614 //===----------------------------------------------------------------------===//
1615 
1616 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1617  OpAsmSetValueNameFn setNameFn) {
1618  for (Value v : getRegionInputArgs())
1619  setNameFn(v, "in");
1620  for (Value v : getRegionOutputArgs())
1621  setNameFn(v, "init");
1622 }
1623 
1624 void ReduceOp::getAsmResultNames(
1625  function_ref<void(Value, StringRef)> setNameFn) {
1626  if (!getResults().empty())
1627  setNameFn(getResults().front(), "reduced");
1628 }
1629 
1630 void ReduceOp::build(
1631  OpBuilder &builder, OperationState &result, ValueRange inputs,
1632  ValueRange inits, ArrayRef<int64_t> dimensions,
1633  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1634  ArrayRef<NamedAttribute> attributes) {
1635  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1636  result.addAttributes(attributes);
1637 
1638  // Add output types for `RankedTensorType` output arguments.
1639  for (Value init : inits) {
1640  Type initType = init.getType();
1641  if (llvm::isa<RankedTensorType>(initType))
1642  result.addTypes(initType);
1643  }
1644 
1645  if (bodyBuild)
1646  buildGenericRegion(builder, result.location, *result.regions.front(),
1647  inputs, inits, bodyBuild);
1648 }
1649 
1650 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1651  int64_t inputRank =
1652  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1653  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1654  utils::IteratorType::parallel);
1655  for (int64_t reductionDim : getDimensions())
1656  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1657  return iteratorTypes;
1658 }
1659 
1660 ArrayAttr ReduceOp::getIndexingMaps() {
1661  int64_t inputRank =
1662  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1663  SmallVector<AffineMap> affineMaps(
1664  getNumDpsInputs(),
1665  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1666  AffineMap resultMap =
1667  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1668  .dropResults(getDimensions());
1669  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1670  affineMaps.push_back(resultMap);
1671  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1672 }
1673 
1674 void ReduceOp::getEffects(
1675  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1676  &effects) {
1677  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1678 }
1679 
1680 Speculation::Speculatability ReduceOp::getSpeculatability() {
1681  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1682 }
1683 
1684 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1685  NamedAttrList &attributes,
1686  StringRef attributeName) {
1687  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1688  return failure();
1689 
1690  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1691  return success();
1692 }
1693 
1694 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1695  std::optional<OperationName> payloadOpName;
1696  NamedAttrList payloadOpAttrs;
1697  if (succeeded(parser.parseOptionalLBrace())) {
1698  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1699  if (failed(operationName))
1700  return failure();
1701  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1702  return failure();
1703  payloadOpName = operationName.value();
1704  if (parser.parseRBrace())
1705  return failure();
1706  }
1707 
1708  if (parseDstStyleOp(
1709  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1710  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1711  }))
1712  return failure();
1713 
1714  if (payloadOpName.has_value()) {
1715  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1716  ArrayRef(result.operands), /*initFirst=*/true);
1717  } else {
1718  SmallVector<OpAsmParser::Argument> regionArgs;
1719  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1720  /*allowType=*/true, /*allowAttrs=*/true)) {
1721  return failure();
1722  }
1723 
1724  Region *body = result.addRegion();
1725  if (parser.parseRegion(*body, regionArgs))
1726  return failure();
1727  }
1728 
1729  return success();
1730 }
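// A sketch of the short form this parser accepts (example IR, not from this
// file); the payload op receives the init value as its first operand:
//
//   %reduced = linalg.reduce { arith.addf }
//       ins(%in : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>)
//       dimensions = [1]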
1731 
1732 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1733  ArrayRef<int64_t> attributeValue) {
1734  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1735 }
1736 
1737 void ReduceOp::print(OpAsmPrinter &p) {
1738  Block *mapper = getBody();
1739  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1740  if (payloadOp) {
1741  printShortForm(p, payloadOp);
1742  }
1743 
1744  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1745  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1746  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1747  if (!payloadOp) {
1748  // Print region if the payload op was not detected.
1749  p.increaseIndent();
1750  p.printNewline();
1751  p << "(";
1752  llvm::interleaveComma(mapper->getArguments(), p,
1753  [&](auto arg) { p.printRegionArgument(arg); });
1754  p << ") ";
1755 
1756  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1757  p.decreaseIndent();
1758  }
1759 }
1760 
1761 LogicalResult ReduceOp::verify() {
1762  ArrayRef<int64_t> dimensionsRef = getDimensions();
1763 
1764  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1765  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1766  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1767  return emitOpError() << "expects all inputs to have the same shapes. "
1768  "Shape at input-index "
1769  << i
1770  << " is not equal to the shape at input-index 0.";
1771  }
1772  }
1773  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1774  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1775  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1776  return emitOpError() << "expects all outputs to have the same shapes. "
1777  "Shape at output-index "
1778  << i
1779  << " is not equal to the shape at output-index 0.";
1780  }
1781  }
1782  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1783  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1784 
1785  DenseSet<int64_t> dimensionsToReduce;
1786  for (int64_t dimension : dimensionsRef) {
1787  if (dimension < 0 || dimension >= inputType.getRank()) {
1788  return emitOpError()
1789  << "dimensions for reduction should be in the range [0, "
1790  << inputType.getRank() - 1 << "].";
1791  }
1792  dimensionsToReduce.insert(dimension);
1793  }
1794 
1795  auto inputDims = inputType.getShape();
1796  auto initDims = initType.getShape();
1797 
1798  // Input dimensions that will be left after the reduction.
1799  SmallVector<int64_t> reducedInputDims;
1800  for (const auto &en : llvm::enumerate(inputDims)) {
1801  if (!dimensionsToReduce.count(en.index()))
1802  reducedInputDims.push_back(en.value());
1803  }
1804 
1805  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1806  return emitOpError() << "number of dimensions after reduction "
1807  << reducedInputDims.size()
1808  << " doesn't match the init rank "
1809  << initType.getRank();
1810  }
1811 
1812  if (reducedInputDims != initDims)
1813  return emitOpError() << "init dimensions [" << initDims
1814  << "] doesn't match input dimensions after reduction ["
1815  << reducedInputDims << "]";
1816 
1817  Block *block = getBody();
1818  if (block->getNumArguments() != this->getNumOperands())
1819  return emitOpError()
1820  << "mismatching number of operands and block arguments";
1821 
1822  // Check that the first block arguments match the element type of the inputs.
1823  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1824  Type inputElementType =
1825  llvm::cast<ShapedType>(input.getType()).getElementType();
1826  if (inputElementType != bbArg.getType())
1827  return emitOpError()
1828  << "input element type " << inputElementType
1829  << " does not match corresponding block argument type "
1830  << bbArg.getType();
1831  }
1832 
1833  // Check that the last block arguments match the element type of the outputs.
1834  for (auto [output, bbArg] : llvm::zip(
1835  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1836  auto outputElementType =
1837  llvm::cast<ShapedType>(output.getType()).getElementType();
1838  if (outputElementType != bbArg.getType())
1839  return emitOpError()
1840  << "output element type " << outputElementType
1841  << " does not match corresponding block argument type "
1842  << bbArg.getType();
1843  }
1844  return success();
1845 }
1846 
1847 //===----------------------------------------------------------------------===//
1848 // TransposeOp
1849 //===----------------------------------------------------------------------===//
1850 
1851 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1852  Region &region, ValueRange inputs,
1853  ValueRange outputs) {
1854  buildGenericRegion(builder, loc, region, inputs, outputs,
1855  [](OpBuilder &b, Location loc, ValueRange args) {
1856  if (!args.empty())
1857  b.create<linalg::YieldOp>(loc, args[0]);
1858  });
1859 }
1860 
1861 void TransposeOp::build(::mlir::OpBuilder &builder,
1862  ::mlir::OperationState &result, Value input, Value init,
1863  DenseI64ArrayAttr permutation,
1864  ArrayRef<NamedAttribute> attributes) {
1865  result.addOperands(input);
1866  result.addOperands(init);
1867  result.addAttribute(getPermutationAttrName(result.name), permutation);
1868  result.addAttributes(attributes);
1869 
1870  // Add output types for `RankedTensorType` output arguments.
1871  Type initType = init.getType();
1872  if (llvm::isa<RankedTensorType>(initType))
1873  result.addTypes(initType);
1874 
1875  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1876  init);
1877 }
1878 
1879 void TransposeOp::build(::mlir::OpBuilder &builder,
1880  ::mlir::OperationState &result, Value input, Value init,
1881  ArrayRef<int64_t> permutation,
1882  ArrayRef<NamedAttribute> attributes) {
1883  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1884  attributes);
1885 }
1886 
1887 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1888  if (failed(parseDstStyleOp(
1889  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1890  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1891  })))
1892  return failure();
1893 
1894  OpBuilder builder(parser.getContext());
1895  buildIdentityRegion(builder, result.location, *result.addRegion(),
1896  /*inputs=*/result.operands,
1897  /*outputs=*/{});
1898  return success();
1899 }
1900 
1901 void TransposeOp::getAsmResultNames(
1902  function_ref<void(Value, StringRef)> setNameFn) {
1903  if (!getResults().empty())
1904  setNameFn(getResults().front(), "transposed");
1905 }
1906 
1907 void TransposeOp::print(OpAsmPrinter &p) {
1908  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1909  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1910  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1911 }
1912 
1913 LogicalResult TransposeOp::verify() {
1914  ArrayRef<int64_t> permutationRef = getPermutation();
1915 
1916  if (!isPermutationVector(permutationRef))
1917  return emitOpError("permutation is not valid");
1918 
1919  auto inputType = getInput().getType();
1920  auto initType = getInit().getType();
1921 
1922  int64_t rank = inputType.getRank();
1923 
1924  if (rank != initType.getRank())
1925  return emitOpError() << "input rank " << rank
1926  << " does not match init rank " << initType.getRank();
1927 
1928  if (rank != static_cast<int64_t>(permutationRef.size()))
1929  return emitOpError() << "size of permutation " << permutationRef.size()
1930  << " does not match the argument rank " << rank;
1931 
1932  auto inputDims = inputType.getShape();
1933  auto initDims = initType.getShape();
1934 
1935  for (int64_t i = 0; i < rank; ++i) {
1936  int64_t inputDim = inputDims[permutationRef[i]];
1937  int64_t initDim = initDims[i];
1938 
1939  if (inputDim != initDim) {
1940  return emitOpError() << "dim(result, " << i << ") = " << initDim
1941  << " doesn't match dim(input, permutation[" << i
1942  << "]) = " << inputDim;
1943  }
1944  }
1945 
1946  return success();
1947 }
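// A sketch of an op that satisfies this verifier (example IR):
// dim(init, i) must equal dim(input, permutation[i]) for each i.
//
//   linalg.transpose ins(%in : tensor<16x64xf32>)
//       outs(%init : tensor<64x16xf32>) permutation = [1, 0]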
1948 
1949 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1950  int64_t rank = getInit().getType().getRank();
1951  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1952 }
1953 
1954 ArrayAttr TransposeOp::getIndexingMaps() {
1955  Builder builder(getContext());
1956  int64_t rank = getInit().getType().getRank();
1957  return builder.getAffineMapArrayAttr(
1958  {inversePermutation(AffineMap::getPermutationMap(
1959  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1960  builder.getMultiDimIdentityMap(rank)});
1961 }
1962 
1963 void TransposeOp::getEffects(
1964  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1965  &effects) {
1966  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1967 }
1968 
1969 Speculation::Speculatability TransposeOp::getSpeculatability() {
1970  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1971 }
1972 
1973 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1974  SmallVectorImpl<OpFoldResult> &result) {
1975  // Only the tensor type is supported.
1976  if (!isa<TensorType>(getInput().getType()))
1977  return failure();
1978 
1979  // Rank-0 transpose: an empty permutation trivially folds to the input.
1980  if (getPermutation().size() == 0) {
1981  result.push_back(getInput());
1982  return success();
1983  }
1984  // Identity permutation.
1985  if (isIdentityPermutation(getPermutation())) {
1986  result.push_back(getInput());
1987  return success();
1988  }
1989 
1990  return failure();
1991 }
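// For example (a sketch): a tensor-typed transpose with permutation = [0, 1]
// is an identity and folds to its input; a memref-typed transpose is left
// untouched.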
1992 
1993 /// Fold transpose with transpose.
1994 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1995  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1996 
1997  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1998  PatternRewriter &rewriter) const override {
1999  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2000  if (!defTransposeOp)
2001  return failure();
2002  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2003  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2004  SmallVector<int64_t> foldedPerms;
2005  foldedPerms.reserve(perms.size());
2006  for (int64_t perm : perms)
2007  foldedPerms.push_back(defPerms[perm]);
2008 
2009  rewriter.replaceOpWithNewOp<TransposeOp>(
2010  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2011  foldedPerms);
2012  return success();
2013  }
2014 };
2015 
2016 /// This pattern canonicalizes transpose by swapping the order of
2017 /// broadcast and transpose:
2018 ///   transpose(broadcast(input)) -> broadcast(transpose(input))
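/// For example (a sketch): broadcasting tensor<2x4xf32> to tensor<2x3x4xf32>
/// with dimensions = [1] and then transposing with permutation = [2, 1, 0]
/// becomes a transpose of the tensor<2x4xf32> input with permutation = [1, 0]
/// followed by a broadcast with dimensions = [1] to tensor<4x3x2xf32>.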
2019 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2020  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2021 
2022  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2023  PatternRewriter &rewriter) const override {
2024  Value input = transposeOp.getInput();
2025  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2026  if (!input.hasOneUse() || !broadcastOp)
2027  return failure();
2028 
2029  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2030  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2031 
2032  // Get new perms and new dimensions.
2033  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2034  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2035  SmallVector<int64_t> resultDimensions;
2036  unsigned dimensionSize = dimensions.size();
2037  for (unsigned i = 0; i < dimensionSize; ++i)
2038  resultDimensions.push_back(invertPerm[dimensions[i]]);
2039 
2040  // Create transpose result.
2041  Value broadcastInput = broadcastOp.getInput();
2042  Location loc = transposeOp.getLoc();
2043  MLIRContext *ctx = transposeOp.getContext();
2044  SmallVector<OpFoldResult> dims;
2045  auto broadcastInputTy =
2046  mlir::cast<RankedTensorType>(broadcastInput.getType());
2047  unsigned inputRank = broadcastInputTy.getRank();
2048  for (unsigned i = 0; i < inputRank; ++i) {
2049  if (broadcastInputTy.isDynamicDim(i)) {
2050  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2051  ->getResult(0));
2052  } else {
2053  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2054  broadcastInputTy.getDimSize(i)));
2055  }
2056  }
2057  SmallVector<OpFoldResult> transposeResultShapes =
2058  applyPermutation(dims, resultPerms);
2059  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2060  transposeOp.getLoc(), transposeResultShapes,
2061  broadcastInputTy.getElementType());
2062 
2063  // Create broadcast(transpose(input)).
2064  Value transposeResult =
2065  rewriter
2066  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2067  resultPerms)
2068  ->getResult(0);
2069  rewriter.replaceOpWithNewOp<BroadcastOp>(
2070  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2071  return success();
2072  }
2073 };
2074 
2075 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2076  MLIRContext *context) {
2077  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2078 }
2079 
2080 //===----------------------------------------------------------------------===//
2081 // BroadcastOp
2082 //===----------------------------------------------------------------------===//
2083 
2084 void BroadcastOp::build(::mlir::OpBuilder &builder,
2085  ::mlir::OperationState &result, Value input, Value init,
2086  DenseI64ArrayAttr dimensions,
2087  ArrayRef<NamedAttribute> attributes) {
2088  result.addOperands(input);
2089  result.addOperands(init);
2090  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2091  result.addAttributes(attributes);
2092 
2093  // Add output types for `RankedTensorType` output arguments.
2094  Type initType = init.getType();
2095  if (llvm::isa<RankedTensorType>(initType))
2096  result.addTypes(initType);
2097 
2098  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2099  init);
2100 }
2101 
2102 void BroadcastOp::build(::mlir::OpBuilder &builder,
2103  ::mlir::OperationState &result, Value input, Value init,
2104  ArrayRef<int64_t> dimensions,
2105  ArrayRef<NamedAttribute> attributes) {
2106  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2107  attributes);
2108 }
2109 
2110 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2111  if (failed(parseDstStyleOp(
2112  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2113  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2114  })))
2115  return failure();
2116 
2117  OpBuilder builder(parser.getContext());
2118  buildIdentityRegion(builder, result.location, *result.addRegion(),
2119  /*inputs=*/result.operands,
2120  /*outputs=*/{});
2121  return success();
2122 }
2123 
2124 void BroadcastOp::getAsmResultNames(
2125  function_ref<void(Value, StringRef)> setNameFn) {
2126  if (!getResults().empty())
2127  setNameFn(getResults().front(), "broadcasted");
2128 }
2129 
2130 void BroadcastOp::print(OpAsmPrinter &p) {
2131  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2132  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2133  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2134 }
2135 
2136 LogicalResult BroadcastOp::verify() {
2137  ArrayRef<int64_t> dimensionsRef = getDimensions();
2138 
2139  auto inputType = getInput().getType();
2140  auto initType = getInit().getType();
2141 
2142  int64_t inputRank = inputType.getRank();
2143  int64_t initRank = initType.getRank();
2144 
2145  auto inputShape = inputType.getShape();
2146  auto initShape = initType.getShape();
2147 
2148  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2149  return emitOpError() << "input rank plus added dimensions does not "
2150  "match init rank. input rank: "
2151  << inputRank
2152  << ", dimensions size: " << dimensionsRef.size()
2153  << ", init rank: " << initRank;
2154 
2155  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2156  if (dim < 0 || dim >= initRank)
2157  return emitOpError() << "dimension " << idx
2158  << " is out of range. expected range: [0, "
2159  << initRank - 1 << "], got: " << dim;
2160  }
2161 
2162  // Mapping from input dims to init dims.
2163  SmallVector<int64_t> dimMap;
2164  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2165  if (!llvm::is_contained(dimensionsRef, dim))
2166  dimMap.push_back(dim);
2167  }
2168 
2169  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2170  // This dimension is mapped from the input. Init and input dims should
2171  // match.
2172  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2173  return emitOpError() << "input dim " << inputDimIdx
2174  << " should match init dim " << initDimIdx
2175  << ". input: " << inputShape[inputDimIdx]
2176  << ", init: " << initShape[initDimIdx];
2177  }
2178 
2179  return success();
2180 }
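// A sketch of an op that satisfies this verifier (example IR): `dimensions`
// lists the init dimensions that are added, here dim 1 of size 8:
//
//   linalg.broadcast ins(%in : tensor<4x16xf32>)
//       outs(%init : tensor<4x8x16xf32>) dimensions = [1]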
2181 
2182 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2183  int64_t rank = getInit().getType().getRank();
2184  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2185 }
2186 
2187 ArrayAttr BroadcastOp::getIndexingMaps() {
2188  Builder builder(getContext());
2189  int64_t rank = getInit().getType().getRank();
2190  return builder.getAffineMapArrayAttr(
2191  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2192  builder.getMultiDimIdentityMap(rank)});
2193 }
2194 
2195 void BroadcastOp::getEffects(
2196  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2197  &effects) {
2198  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2199 }
2200 
2201 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2202  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2203 }
2204 
2205 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2206  MLIRContext *context) {
2207  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2208 }
2209 
2210 //===----------------------------------------------------------------------===//
2211 // YieldOp
2212 //===----------------------------------------------------------------------===//
2213 
2214 void linalg::YieldOp::print(OpAsmPrinter &p) {
2215  if (getNumOperands() > 0)
2216  p << ' ' << getOperands();
2217  p.printOptionalAttrDict((*this)->getAttrs());
2218  if (getNumOperands() > 0)
2219  p << " : " << getOperandTypes();
2220 }
2221 
2222 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2223  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2224  SmallVector<Type, 2> types;
2225  SMLoc loc = parser.getCurrentLocation();
2226  return failure(parser.parseOperandList(opInfo) ||
2227  parser.parseOptionalAttrDict(result.attributes) ||
2228  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2229  parser.resolveOperands(opInfo, types, loc, result.operands));
2230 }
2231 
2232 // Check that the number and types of the yielded operands match the element
2233 // types of the LinalgOp interface's shaped operands.
2234 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2235  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2236  return op.emitOpError("expected number of yield values (")
2237  << op.getNumOperands()
2238  << ") to match the number of inits / outs operands of the enclosing "
2239  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2240 
2241  for (OpOperand &opOperand : op->getOpOperands()) {
2242  OpOperand *outputOperand =
2243  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2244  Type elementType = outputOperand->get().getType();
2245  if (isa<MemRefType, RankedTensorType>(elementType))
2246  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2247  if (opOperand.get().getType() != elementType)
2248  return op.emitOpError("type of yield operand ")
2249  << (opOperand.getOperandNumber() + 1) << " ("
2250  << opOperand.get().getType() << ") doesn't match "
2251  << "the element type of the enclosing linalg.generic op ("
2252  << elementType << ")";
2253  }
2254  return success();
2255 }
2256 
2257 LogicalResult linalg::YieldOp::verify() {
2258  auto *parentOp = (*this)->getParentOp();
2259  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2260  return emitOpError("expected single non-empty parent region");
2261 
2262  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2263  return verifyYield(*this, linalgOp);
2264 
2265  return emitOpError("expected parent op with LinalgOp interface");
2266 }
2267 
2268 //===----------------------------------------------------------------------===//
2269 // IndexOp
2270 //===----------------------------------------------------------------------===//
2271 
2272 LogicalResult IndexOp::verify() {
2273  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2274  if (!linalgOp)
2275  return emitOpError("expected parent op with LinalgOp interface");
2276  if (linalgOp.getNumLoops() <= getDim())
2277  return emitOpError("expected dim (")
2278  << getDim() << ") to be lower than the number of loops ("
2279  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2280  return success();
2281 }
2282 
2283 /////// Operations corresponding to library calls defined with Tablegen ////////
2284 
2285 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2286 
2287 #define GET_OP_CLASSES
2288 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2289 
2290 #define GET_OP_CLASSES
2291 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2292 
2293 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2294  unsigned rank,
2295  MLIRContext *context) {
2296  if (maybeMap)
2297  return *maybeMap;
2298  if (rank == 0)
2299  return AffineMap::get(context);
2300  return AffineMap::getMultiDimIdentityMap(rank, context);
2301 }
2302 
2303 SmallVector<AffineExpr, 4>
2304 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2305  MLIRContext *context) {
2306  SmallVector<AffineExpr, 4> res;
2307  res.reserve(num);
2308  for (unsigned i = 0; i < num; ++i)
2309  res.push_back(getAffineDimExpr(startIdx++, context));
2310  return res;
2311 }
2312 
2313 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2314  ArrayRef<AffineExpr> b) {
2315  auto rangeA = llvm::make_range(a.begin(), a.end());
2316  auto rangeB = llvm::make_range(b.begin(), b.end());
2317  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2318  return llvm::to_vector<4>(concatRanges);
2319 }
2320 
2321 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2322  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2323  ss << "view";
2324  for (auto size : memref.getShape())
2325  if (size < 0)
2326  ss << "sx";
2327  else
2328  ss << size << "x";
2329  if (failed(appendMangledType(ss, memref.getElementType())))
2330  return failure();
2331  if (auto as = memref.getMemorySpace()) {
2332  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2333  ss << "as" << attr.getInt();
2334  else
2335  return failure();
2336  }
2337  return success();
2338  }
2339  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2340  ss << "vector";
2341  llvm::interleave(
2342  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2343  if (failed(appendMangledType(ss, vec.getElementType())))
2344  return failure();
2345  return success();
2346  }
2347  if (t.isSignlessIntOrIndexOrFloat()) {
2348  ss << t;
2349  return success();
2350  }
2351  return failure();
2352 }
2353 
2354 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2355  assert(isa<LinalgOp>(op));
2356  std::string name(op->getName().getStringRef().str());
2357  std::string fun = "";
2358  for (NamedAttribute kv : op->getAttrs()) {
2359  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2360  fun = stringifyEnum(ufa.getValue()).str() + "_";
2361  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2362  fun = stringifyEnum(bfa.getValue()).str() + "_";
2363  }
2364  }
2365  name.reserve(128);
2366  std::replace(name.begin(), name.end(), '.', '_');
2367  llvm::raw_string_ostream ss(name);
2368  ss << "_" << fun;
2369  for (Type t : op->getOperandTypes()) {
2370  if (failed(appendMangledType(ss, t)))
2371  return std::string();
2372  ss << "_";
2373  }
2374  name.pop_back();
2375  return name;
2376 }
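// For example (a sketch): a linalg.matmul on three memref<?x?xf32> operands
// mangles to "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32", with 's'
// standing in for each dynamic size.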
2377 
2378 //===----------------------------------------------------------------------===//
2379 // Canonicalizers and Folders.
2380 //===----------------------------------------------------------------------===//
2381 
2382 namespace {
2383 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2384  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2385 
2386  LogicalResult matchAndRewrite(LinalgOp op,
2387  PatternRewriter &rewriter) const override {
2388  for (OpOperand &opOperand : op->getOpOperands()) {
2389  // Linalg "inputs" may be either tensor or memref type.
2390  // tensor<0xelt_type> is a convention that may not always mean
2391  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2392  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2393  if (!mt)
2394  continue;
2395  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2396  rewriter.eraseOp(op);
2397  return success();
2398  }
2399  }
2400  return failure();
2401  }
2402 };
2403 
2404 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2405 /// result that is more static than the linalg op.
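/// For example (a sketch):
///   %0 = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// is rewritten so the generic op produces tensor<4x8xf32> directly, with a
/// cast inserted on the outs operand and a cast back to the dynamic type for
/// any remaining users.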
2406 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2407  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2408 
2409  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2410  PatternRewriter &rewriter) const override {
2411  if (!tensor::canFoldIntoProducerOp(castOp))
2412  return failure();
2413 
2414  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2415  if (!linalgOp)
2416  return failure();
2417 
2418  // The cast can be in a conditionally reachable region, in which case folding
2419  // would generate invalid code. Only conservatively fold ops in the same
2420  // block for now.
2421  if (castOp->getBlock() != linalgOp->getBlock())
2422  return failure();
2423 
2424  OpBuilder::InsertionGuard guard(rewriter);
2425  rewriter.setInsertionPoint(linalgOp);
2426 
2427  Location loc = linalgOp.getLoc();
2428  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2429  unsigned resultNumber = resultValue.getResultNumber();
2430  auto resultType =
2431  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2432  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2433  // going from a more dynamic shape to a less dynamic shape. If the producer
2434  // for this cast, i.e. producer of the out operand, is also an operation
2435  // that folds with tensor.cast consumer (like this pattern), the cast will
2436  // continue to propagate as far up the stack as it can go.
2437  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2438  Value newOperand =
2439  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2440  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2441  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2442  linalgOp.getDpsInits().end());
2443  outputOperands[resultNumber] = newOperand;
2444  newOperands.append(outputOperands.begin(), outputOperands.end());
2445 
2446  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2447  linalgOp->result_type_end());
2448  resultTypes[resultNumber] = resultType;
2449  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2450 
2451  // Create a tensor.cast operation back to the original type.
2452  Value castBack = rewriter.create<tensor::CastOp>(
2453  loc, resultValue.getType(), newOp->getResult(resultNumber));
2454 
2455  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2456  results[resultNumber] = castBack;
2457  rewriter.replaceOp(linalgOp, results);
2458  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2459  return success();
2460  }
2461 };
2462 
2463 /// For each operand in `operands`, this function maps the static sizes of its
2464 /// dimensions to the corresponding affine dim expressions.
2465 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2466  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2467  for (OpOperand &opOperand : operands) {
2468  if (linalgOp.isScalar(&opOperand))
2469  continue;
2470  Value src = opOperand.get();
2471  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2472  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2473 
2474  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2475  // `tensor.cast` operation and source of the cast operation has a static
2476  // shape, then assign it to the `sourceShape`.
2477  auto *parentOp = src.getDefiningOp();
2478  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2479  if (parentOp) {
2480  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2481  Value castSource = castOp.getSource();
2482  auto castSourceType =
2483  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2484  if (castSourceType && castSourceType.hasStaticShape())
2485  sourceShape = castSourceType.getShape();
2486  }
2487  }
2488 
2489  // If the source shape's dimension has a static shape, map the affine dim
2490  // expression to the known static size.
2491  for (unsigned i = 0; i < sourceShape.size(); i++) {
2492  if (sourceType.isDynamicDim(i))
2493  continue;
2494  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2495  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2496  }
2497  }
2498 }
2499 
2500 /// Creates a new operand for 'opOperand' of `linalgOp` with the static sizes
2501 /// mapped in `affineExprToSize`. New operands are appended to `newOperands`
2502 /// and their result types are stored in `resultTypes`. If `opOperand` requires
2503 /// no change, `changeNeeded` is left unchanged and the same operand is added
2504 /// to the `newOperands` list.
2505 static void createNewOperandWithStaticSizes(
2506  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2507  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2508  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2509  bool &changeNeeded) {
2510  Value src = opOperand->get();
2511  newOperands.push_back(src);
2512  if (linalgOp.isScalar(opOperand))
2513  return;
2514  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2515  Type resultType = sourceType;
2516  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2517  resultTypes.push_back(resultType);
2518  return;
2519  }
2520  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2521  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2522  SmallVector<int64_t> newShape;
2523  // If operand is updated with new shape, `newOperandNeeded` will be
2524  // true.
2525  bool newOperandNeeded = false;
2526  for (unsigned i = 0; i < sourceShape.size(); i++) {
2527  int64_t dimShape = sourceShape[i];
2528  AffineExpr dimExpr = sourceMap.getResult(i);
2529  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2530  newShape.push_back(dimShape);
2531  continue;
2532  }
2533  // Dimension has a dynamic shape and corresponding affine dim
2534  // expression is present in the map. So assign the size for the
2535  // given affine dim expression to the dimension.
2536  newShape.push_back(affineExprToSize[dimExpr]);
2537  newOperandNeeded = true;
2538  }
2539  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2540  if (newOperandNeeded) {
2541  changeNeeded = true;
2542  // Get the new operand value given its size and element type by
2543  // casting it.
2544  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2545  unsigned index = opOperand->getOperandNumber();
2546  newOperands[index] = newOperand;
2547  }
2548  if (linalgOp.isDpsInit(opOperand))
2549  resultTypes.push_back(resultType);
2550 }
2551 
2552 /// Static shapes for the operands can be inferred if any one of the operands
2553 /// has a static shape. This can be done by referring to the affine dim
2554 /// expressions for the operand.
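/// For example (a sketch): under identity indexing maps, if one operand is
/// tensor<4x?xf32> and another is tensor<?x8xf32>, both sizes become known
/// and each operand can be tensor.cast to tensor<4x8xf32> before the op is
/// cloned with the refined types.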
2555 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2556  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2557 
2558  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2559  PatternRewriter &rewriter) const override {
2560  if (!linalgOp.hasPureTensorSemantics())
2561  return failure();
2562 
2563  // Maps must be projected permutations.
2564  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2565  return !map.isProjectedPermutation();
2566  }))
2567  return failure();
2568 
2569  // Maps affine dim expressions to the static size of that dimension.
2570  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2571  Location loc = linalgOp.getLoc();
2572 
2573  // For each of the affine dim expression, check if the size is known. If
2574  // known add that in the map.
2575  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2576 
2577  SmallVector<Value> newOperands;
2578  SmallVector<Type> resultTypes;
2579 
2580  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2581  // change in their types.
2582  bool changeNeeded = false;
2583  newOperands.reserve(linalgOp->getNumOperands());
2584  resultTypes.reserve(linalgOp.getNumDpsInits());
2585 
2586  // Iterate over all the operands and update the static sizes.
2587  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2588  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2589  affineExprToSize, linalgOp, newOperands,
2590  resultTypes, changeNeeded);
2591  }
2592 
2593  // If the generic op has all the required static information, no
2594  // canonicalization needed.
2595  if (!changeNeeded)
2596  return failure();
2597 
2598  // Clone op.
2599  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2600  SmallVector<Value> replacements;
2601  replacements.reserve(newOp->getNumResults());
2602  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2603  Value newResult = std::get<1>(it);
2604  Value oldResult = std::get<0>(it);
2605  Type newType = newResult.getType();
2606  Type oldType = oldResult.getType();
2607  replacements.push_back(
2608  (newType != oldType)
2609  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2610  : newResult);
2611  }
2612  rewriter.replaceOp(linalgOp, replacements);
2613  return success();
2614  }
2615 };
2616 
2617 } // namespace
2618 
2619 // All named ops canonicalizers and folders are auto-generated in the
2620 // .cpp.inc.
2621 
2622 //===----------------------------------------------------------------------===//
2623 // SoftmaxOp
2624 //===----------------------------------------------------------------------===//
2625 
2626 LogicalResult SoftmaxOp::verify() {
2627  ShapedType inputType = getInputOperandType();
2628  ShapedType outputType = getOutputOperandType();
2629 
2630  ArrayRef<int64_t> inputShape = inputType.getShape();
2631  ArrayRef<int64_t> outputShape = outputType.getShape();
2632  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2633  return emitOpError("incompatible output shape");
2634 
2635  int64_t inputRank = getInputOperandRank();
2636  int64_t dimension = getDimension();
2637  if ((dimension < 0) || (dimension >= inputRank))
2638  return emitOpError("incorrect dimension specified");
2639 
2640  return success();
2641 }
2642 
2643 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2644  int64_t operandRank = getInputOperandRank();
2645  SmallVector<Range> loopBounds(operandRank);
2646  Location loc = getLoc();
2647  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2648  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2649  Value source = getInput();
2650  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2651  loopBounds[dim].offset = zero;
2652  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2653  loopBounds[dim].stride = one;
2654  }
2655  return loopBounds;
2656 }
2657 
2658 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2659  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2660  utils::IteratorType::parallel);
2661  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2662  return iteratorTypes;
2663 }
2664 
2665 FailureOr<TilingResult>
2666 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2667  ArrayRef<OpFoldResult> offsets,
2668  ArrayRef<OpFoldResult> sizes) {
2669  int64_t rank = getInputOperandRank();
2670  auto oneAttr = builder.getI64IntegerAttr(1);
2671  SmallVector<OpFoldResult> strides(rank, oneAttr);
2672  SmallVector<Value> tiledOperands;
2673  Operation *inputSlice =
2674  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2675  if (!inputSlice) {
2676  return emitOpError("failed to compute input slice");
2677  }
2678  tiledOperands.emplace_back(inputSlice->getResult(0));
2679  Operation *outputSlice =
2680  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2681  if (!outputSlice) {
2682  return emitOpError("failed to compute output slice");
2683  }
2684  tiledOperands.emplace_back(outputSlice->getResult(0));
2685 
2686  SmallVector<Type, 4> resultTypes;
2687  if (hasPureTensorSemantics())
2688  resultTypes.push_back(tiledOperands[1].getType());
2689  Operation *tiledOp =
2690  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2691 
2692  return TilingResult{
2693  {tiledOp},
2694  SmallVector<Value>(tiledOp->getResults()),
2695  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2696 }
2697 
2698 LogicalResult SoftmaxOp::getResultTilePosition(
2699  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2700  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2701  SmallVector<OpFoldResult> &resultSizes) {
2702  if (resultNumber == 0) {
2703  resultOffsets.assign(offsets.begin(), offsets.end());
2704  resultSizes.assign(sizes.begin(), sizes.end());
2705  return success();
2706  }
2707  return failure();
2708 }
2709 
2710 // cast(dynamic) -> static.
2711 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2712  return memref::foldMemRefCast(*this);
2713 }
2714 
2715 LogicalResult
2716 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2717  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2718  SmallVector<OpFoldResult> shapes;
2719  Location loc = getOperation()->getLoc();
2720  IRRewriter rewriter(b);
2721  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2722  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2723  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2724  if (!outputShapedType.isDynamicDim(dim)) {
2725  // Static dim: Return IntegerAttr.
2726  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2727  } else {
2728  // Dynamic dim: Return Value.
2729  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2730  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2731  }
2732  }
2733  reifiedReturnShapes.emplace_back(std::move(shapes));
2734  return success();
2735 }
2736 
2737 void SoftmaxOp::getEffects(
2738  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2739  &effects) {
2740  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2741  if (!llvm::isa<MemRefType>(operand.getType()))
2742  continue;
2743  effects.emplace_back(MemoryEffects::Read::get(),
2744  &getOperation()->getOpOperand(index), /*stage=*/0,
2745  /*effectOnFullRegion=*/true,
2746  SideEffects::DefaultResource::get());
2747  }
2748 
2749  for (OpOperand &operand : getDpsInitsMutable()) {
2750  if (!llvm::isa<MemRefType>(operand.get().getType()))
2751  continue;
2752  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2753  /*effectOnFullRegion=*/true,
2754  SideEffects::DefaultResource::get());
2755  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2756  /*effectOnFullRegion=*/true,
2757  SideEffects::DefaultResource::get());
2758  }
2759 }
2760 
2761 // Helper functions for softmax decomposition.
2762 // @{
2763 
2764 // Helper function to produce the iterator types (reduction or parallel) and
2765 // affine maps for the iterators used in the decomposition of softmax.
2766 // This method creates:
2767 // If allParallel == true:
2768 // - iterator type: {parallel, ..., parallel}
2769 // - affine maps:
2770 // -- identity with inputRank dimensions.
2771 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2772 // where N == inputRank.
2773 //
2774 // If allParallel == false:
2775 // - iterator type at dim(i) == parallel for i != \p dim and
2776 // dim(dim) == reduction.
2777 // - affine map:
2778 // -- identity with inputRank dimensions.
2779 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2780 // where N == inputRank.
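// For example (a sketch), inputRank == 3 and dim == 1 produce:
//   identity map:  (d0, d1, d2) -> (d0, d1, d2)
//   reduction map: (d0, d1, d2) -> (d0, d2)
// with iterator types {parallel, reduction, parallel} when allParallel is
// false.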
2781 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2782 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2783  int64_t dim, bool allParallel = false) {
2784  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2785  utils::IteratorType::parallel);
2786  if (!allParallel)
2787  iteratorTypes[dim] = utils::IteratorType::reduction;
2788  MLIRContext *ctxt = builder.getContext();
2789  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2790  SmallVector<AffineExpr, 2> affineExprs;
2791  for (int i = 0; i < inputRank; i++) {
2792  if (i != dim)
2793  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2794  }
2795  auto reductionMap =
2796  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2797  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2798  return std::make_tuple(iteratorTypes, indexingMaps);
2799 }
2800 
2801 // Helper function to produce a linalg.generic that computes a reduction on
2802 // dimension \p dim with the operation type \p T.
2803 template <typename T>
2804 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2805  int64_t dim) {
2806  auto inputType = cast<ShapedType>(input.getType());
2807  ArrayRef<int64_t> inputShape = inputType.getShape();
2808  int64_t inputRank = inputShape.size();
2809  auto [iteratorTypes, indexingMaps] =
2810  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2811  assert(indexingMaps.size() == 2 &&
2812  "We should have two maps: 1 for the input, 1 for the output");
2813  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2814 
2815  auto genericOp = builder.create<linalg::GenericOp>(
2816  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2817  [&](OpBuilder &b, Location loc, ValueRange args) {
2818  Value result = b.create<T>(loc, args[0], args[1]);
2819  b.create<linalg::YieldOp>(loc, result);
2820  });
2821  return genericOp.getResult(0);
2822 }
2823 
2824 /// Produce a linalg generic that computes the second step of the softmax
2825 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2826 /// on dimension \p dim.
2827 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2828  Value max, Value output, int64_t dim) {
2829  auto inputType = cast<ShapedType>(input.getType());
2830  ArrayRef<int64_t> inputShape = inputType.getShape();
2831  int64_t inputRank = inputShape.size();
2832  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2833  builder, inputRank, dim, /*allParallel=*/true);
2834  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2835  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2836  // Add the affine map for the output argument.
2837  indexingMaps.push_back(indexingMaps[0]);
2838  auto genericOp = builder.create<linalg::GenericOp>(
2839  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2840  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2841  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2842  Value result = b.create<math::ExpOp>(loc, diff);
2843  b.create<linalg::YieldOp>(loc, result);
2844  });
2845  return genericOp.getResult(0);
2846 }
2847 
2848 /// Produce a linalg generic that computes the final step of the softmax
2849 /// decomposition.
2850 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2851 /// yield n / d
2852 /// }
2853 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2854  Value denominator, Value output, int64_t dim) {
2855  auto inputType = cast<ShapedType>(numerator.getType());
2856  ArrayRef<int64_t> inputShape = inputType.getShape();
2857  int64_t inputRank = inputShape.size();
2858  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2859  builder, inputRank, dim, /*allParallel=*/true);
2860  assert(indexingMaps.size() == 2 &&
2861  "We should have one map for each input (2)");
2862  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2863  // Add the affine map for the output tensor.
2864  indexingMaps.push_back(indexingMaps[0]);
2865  auto genericOp = builder.create<linalg::GenericOp>(
2866  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2867  indexingMaps, iteratorTypes,
2868  [&](OpBuilder &b, Location loc, ValueRange args) {
2869  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2870  b.create<linalg::YieldOp>(loc, result);
2871  });
2872  return genericOp.getResult(0);
2873 }
2874 // @} End helper functions for softmax decomposition.
2875 
2876 /// Given an N-dimensional tensor x, this method converts
2877 /// softmax(x) to the following sequence of operations:
2878 ///
2879 /// 1. Compute the max of x along dimension d. This results
2880 /// in an N-1 dimensional tensor m.
2881 /// m = max(x, dim = d)
2882 ///
2883 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2884 /// an N dimensional tensor z.
2885 /// z = exp(x - m)
2886 ///
2887 /// 3. Compute the sum of z along dimension d. This results in
2888 /// an N-1 dimensional tensor l.
2889 /// l = sum(z, dim = d)
2890 ///
2891 /// 4. Divide z and l. This gives the N-dimensional softmax.
2892 /// softmax = z / l
2893 ///
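/// For example (a sketch), for x : tensor<2x16xf32> and d = 1, the
/// intermediates m and l have type tensor<2xf32>, while z and the final
/// softmax result have type tensor<2x16xf32>.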
2894 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2895  OpBuilder::InsertionGuard guard(b);
2896  b.setInsertionPoint(*this);
2897  Location loc = getLoc();
2898  Value input = getInput();
2899  ShapedType inputType = getInputOperandType();
2900  Type elementType = inputType.getElementType();
2901  int64_t reductionDim = getDimension();
2902  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2903  Value output = getOutput();
2904  dims.erase(dims.begin() + reductionDim);
2905  // Step 1: Compute max along dim.
2906  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2907  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2908  elementType, b, loc,
2909  /*useOnlyFiniteValue=*/true);
2910  Value neutralForMaxFInit =
2911  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2912  .result();
2913  Value max =
2914  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2915 
2916  // Step 2: Subtract max from input and exponentiate.
2917  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2918 
2919  // Step 3: Compute sum along dim.
2920  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2921  b, loc, /*useOnlyFiniteValue=*/true);
2922  Value zeroInit =
2923  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2924  Value denominator =
2925  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2926 
2927  // Step 4: Compute softmax.
2928  Value result =
2929  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2930  return SmallVector<Value>{result};
2931 }
2932 
2933 //===----------------------------------------------------------------------===//
2934 // WinogradFilterTransformOp
2935 //===----------------------------------------------------------------------===//
2936 
2937 LogicalResult WinogradFilterTransformOp::verify() {
2938  auto filterType = cast<ShapedType>(getFilter().getType());
2939  ArrayRef<int64_t> filterShape = filterType.getShape();
2940  int64_t filterH = filterShape[getFilterHDim()];
2941  int64_t filterW = filterShape[getFilterWDim()];
2942  int64_t r = getR();
2943  int64_t m = getM();
2944 
2945  if (filterH != r && filterH != 1)
2946  return emitOpError("expect filter height either equals to r or 1");
2947  if (filterW != r && filterW != 1)
2948  return emitOpError("expect filter width either equals to r or 1");
2949  if (filterH == 1 && filterW == 1)
2950  return emitOpError("expect either filter height or width equals to r");
2951 
2952  SmallVector<int64_t> expectedOutputShape;
2953  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2954  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2955  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2956  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2957 
2958  auto outputType = cast<ShapedType>(getOutput().getType());
2959  ArrayRef<int64_t> outputShape = outputType.getShape();
2960  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2961  return emitOpError("the output shape is not expected");
2962  }
2963  return success();
2964 }
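// For example (a sketch): with m = 4 and r = 3, a filter of shape
// (F, 3, 3, C) transforms to an output of shape (6, 6, C, F), since
// alpha = m + r - 1 = 6.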
2965 
2966 SmallVector<Range>
2967 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2968  Location loc = getLoc();
2969  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2970  IntegerAttr oneAttr = builder.getIndexAttr(1);
2971  Value filter = getFilter();
2972  int64_t filterRank = getFilterOperandRank();
2973  SmallVector<Range> loopBounds(filterRank);
2974  for (unsigned dim = 0; dim < filterRank; ++dim) {
2975  loopBounds[dim].offset = zeroAttr;
2976  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2977  loopBounds[dim].stride = oneAttr;
2978  }
2979  return loopBounds;
2980 }
2981 
2982 SmallVector<utils::IteratorType>
2983 WinogradFilterTransformOp::getLoopIteratorTypes() {
2984  int64_t filterRank = getFilterOperandRank();
2985  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2986  utils::IteratorType::parallel);
2987  return iteratorTypes;
2988 }
2989 
2990 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2991  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2992  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2993  SmallVector<OpFoldResult> &resultSizes) {
2994  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2995  ShapedType filterType = getFilterOperandType();
2996  ArrayRef<int64_t> filterShape = filterType.getShape();
2997  int64_t filterH = filterShape[getFilterHDim()];
2998  int64_t filterW = filterShape[getFilterWDim()];
2999  int64_t m = getM();
3000  int64_t r = getR();
3001  int64_t alpha = m + r - 1;
3002  int64_t alphaH = filterH != 1 ? alpha : 1;
3003  int64_t alphaW = filterW != 1 ? alpha : 1;
3004  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3005  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3006 
3007  resultOffsets.append(
3008  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3009  resultSizes.append(
3010  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3011 
3012  return success();
3013 }
3014 
3015 /// Implements tiling for winograd_filter_transform.
3016 /// The input of winograd_filter_transform is (F, KH, KW, C).
3017 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3018 /// Users can specify the tile sizes of F and C.
3019 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3020 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3021 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3022  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3023  ArrayRef<OpFoldResult> sizes) {
3024  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3025  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3026  ShapedType filterType = getFilterOperandType();
3027  ArrayRef<int64_t> filterShape = filterType.getShape();
3028  int64_t filterH = filterShape[getFilterHDim()];
3029  int64_t filterW = filterShape[getFilterWDim()];
3030  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3031  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3032  SmallVector<Value> tiledOperands;
3033  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3034 
3035  sliceOffsets.append(
3036  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3037  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3038  sizes[getFilterCDim()]});
3039  int64_t filterRank = getFilterOperandRank();
3040  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3041  Location loc = getLoc();
3042  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3043  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3044  tiledOperands.emplace_back(filterSlice);
3045 
3046  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3047  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3048  resultSizes)))
3049  return failure();
3050 
3051  int64_t outputRank = getOutputOperandRank();
3052  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3053  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3054  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3055  tiledOperands.emplace_back(outputSlice);
3056 
3057  SmallVector<Type> resultTypes;
3058  resultTypes.push_back(tiledOperands[1].getType());
3059  Operation *tiledOp =
3060  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3061 
3062  return TilingResult{
3063  {tiledOp},
3064  SmallVector<Value>(tiledOp->getResults()),
3065  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3066 }
3067 
3068 //===----------------------------------------------------------------------===//
3069 // WinogradInputTransformOp
3070 //===----------------------------------------------------------------------===//
3071 
3072 LogicalResult WinogradInputTransformOp::verify() {
3073  auto inputType = cast<ShapedType>(getInput().getType());
3074  ArrayRef<int64_t> inputShape = inputType.getShape();
3075  int64_t inputH = inputShape[getInputHDim()];
3076  int64_t inputW = inputShape[getInputWDim()];
3077  int m = getM();
3078  int r = getR();
3079  int64_t tileSize = m + r - 1;
3080 
3081  auto outputType = cast<ShapedType>(getOutput().getType());
3082  ArrayRef<int64_t> outputShape = outputType.getShape();
3083  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3084  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3085 
3086  SmallVector<int64_t> expectedOutputShape(6, inputH);
3087  if (ShapedType::isDynamic(inputH)) {
3088  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3089  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3090  } else {
3091  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3092  expectedOutputShape[getOutputTileHDim()] =
3093  leftTransform ? (inputH - (r - 1)) / m : inputH;
3094  }
3095  if (ShapedType::isDynamic(inputW)) {
3096  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3097  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3098  } else {
3099  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3100  expectedOutputShape[getOutputTileWDim()] =
3101  rightTransform ? (inputW - (r - 1)) / m : inputW;
3102  }
3103  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3104  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3105 
3106  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3107  return emitOpError("unexpected output shape");
3108  }
3109  return success();
3110 }
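
A worked instance of the static-shape case above (assumed numbers, with both the left and right transforms applied):

#include <cassert>
#include <cstdint>

void inputTransformShapeExample() {
  int64_t m = 4, r = 3;
  int64_t inputH = 14, inputW = 14;       // H = m * tileH + (r - 1)
  int64_t tileSize = m + r - 1;           // alpha = 6
  int64_t tileH = (inputH - (r - 1)) / m; // 3 tiles along H
  int64_t tileW = (inputW - (r - 1)) / m; // 3 tiles along W
  // Expected output: (alphaH, alphaW, tileH, tileW, N, C) = (6, 6, 3, 3, N, C).
  assert(tileSize == 6 && tileH == 3 && tileW == 3);
}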
3111 
3112 SmallVector<Range>
3113 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3114  Location loc = getLoc();
3115  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3116  IntegerAttr oneAttr = builder.getIndexAttr(1);
3117  Value output = getOutput();
3118  int64_t outputRank = getOutputOperandRank();
3119  SmallVector<Range> loopBounds(outputRank);
3120  for (unsigned dim = 0; dim < outputRank; ++dim) {
3121  loopBounds[dim].offset = zeroAttr;
3122  // alphaH, alphaW, tileH, tileW, N, C
3123  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3124  loopBounds[dim].stride = oneAttr;
3125  }
3126  return loopBounds;
3127 }
3128 
3129 SmallVector<utils::IteratorType>
3130 WinogradInputTransformOp::getLoopIteratorTypes() {
3131  int64_t outputRank = getOutputOperandRank();
3132  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3133  utils::IteratorType::parallel);
3134  return iteratorTypes;
3135 }
3136 
3137 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3138  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3139  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3140  SmallVector<OpFoldResult> &resultSizes) {
3141  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3142  ShapedType outputType = getOutputOperandType();
3143  ArrayRef<int64_t> outputShape = outputType.getShape();
3144  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3145  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3146 
3147  int64_t m = getM();
3148  int64_t r = getR();
3149  int64_t alpha = m + r - 1;
3150  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3151  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3152 
3153  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3154  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3155 
3156  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3157  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3158  offsets[getOutputCDim()]});
3159  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3160  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3161  sizes[getOutputCDim()]});
3162 
3163  return success();
3164 }
3165 
3166 /// Implements tiling for winograd_input_transform.
3167 /// The input of winograd_input_transform is (N, H, W, C).
3168 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3169 /// C). Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3170 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3171 /// the values for the sizes of tileH, tileW, N, C for one tile.
3172 FailureOr<TilingResult>
3173 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3174  ArrayRef<OpFoldResult> offsets,
3175  ArrayRef<OpFoldResult> sizes) {
3176  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3177  int64_t m = getM();
3178  int64_t r = getR();
3179 
3180  ShapedType outputType = getOutputOperandType();
3181  ArrayRef<int64_t> outputShape = outputType.getShape();
3182  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3183  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3184 
3185  Location loc = getLoc();
3186  MLIRContext *context = builder.getContext();
3187  auto identityAffineMap =
3188  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3189  auto offsetAffineMap =
3190  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3191  Value mappedOffsetH = affine::makeComposedAffineApply(
3192  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3193  offsets[getOutputTileHDim()]);
3194  Value mappedOffsetW = affine::makeComposedAffineApply(
3195  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3196  offsets[getOutputTileWDim()]);
3197  auto sizeAffineMap = AffineMap::get(
3198  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3199  Value mappedSizeH = affine::makeComposedAffineApply(
3200  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3201  Value mappedSizeW = affine::makeComposedAffineApply(
3202  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3203 
3204  SmallVector<Value> tiledOperands;
3205  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3206 
3207  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3208  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3209  sliceOffsets.append(
3210  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3211  OpFoldResult sizeH =
3212  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3213  OpFoldResult sizeW =
3214  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3215  sliceSizes.append(
3216  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3217  int64_t inputRank = getInputOperandRank();
3218  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3219  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3220  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3221  tiledOperands.emplace_back(inputSlice);
3222 
3223  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3224  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3225  resultSizes)))
3226  return failure();
3227 
3228  int64_t outputRank = getOutputOperandRank();
3229  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3230  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3231  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3232  tiledOperands.emplace_back(outputSlice);
3233 
3234  SmallVector<Type> resultTypes;
3235  resultTypes.push_back(tiledOperands[1].getType());
3236  Operation *tiledOp =
3237  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3238 
3239  return TilingResult{
3240  {tiledOp},
3241  SmallVector<Value>(tiledOp->getResults()),
3242  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3243 }
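
The two affine maps above reduce to simple scalar formulas; a sketch with assumed values (tile-space coordinates map into the input with an r - 1 overlap between adjacent tiles):

#include <cstdint>

int64_t mappedOffset(int64_t tileOffset, int64_t m) {
  return tileOffset * m;         // d0 * m
}
int64_t mappedSize(int64_t tileSize, int64_t m, int64_t r) {
  return tileSize * m + (r - 1); // d0 * m + (r - 1)
}
// e.g. m = 4, r = 3: tile offset 2 -> input row 8; a 2-tile slice
// covers 2 * 4 + 2 = 10 input rows.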
3244 
3245 //===----------------------------------------------------------------------===//
3246 // WinogradOutputTransformOp
3247 //===----------------------------------------------------------------------===//
3248 
3249 LogicalResult WinogradOutputTransformOp::verify() {
3250  auto valueType = cast<ShapedType>(getValue().getType());
3251  ArrayRef<int64_t> valueShape = valueType.getShape();
3252  int64_t valueH = valueShape[getValueAlphaHDim()];
3253  int64_t valueW = valueShape[getValueAlphaWDim()];
3254  int64_t valueTileH = valueShape[getValueTileHDim()];
3255  int64_t valueTileW = valueShape[getValueTileWDim()];
3256  int m = getM();
3257  int r = getR();
3258  bool leftTransform = valueH != 1;
3259  bool rightTransform = valueW != 1;
3260 
3261  int64_t outputRank = getOutputOperandRank();
3262  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3263  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3264  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3265  } else {
3266  if (valueH != (leftTransform ? m + r - 1 : 1))
3267  return emitOpError("expected input height to equal the input tile size");
3268  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3269  }
3270  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3271  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3272  } else {
3273  if (valueW != (rightTransform ? m + r - 1 : 1))
3274  return emitOpError("expected input width to equal the input tile size");
3275  expectedOutputShape[getOutputWDim()] =
3276  (rightTransform ? m : 1) * valueTileW;
3277  }
3278  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3279  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3280 
3281  auto outputType = cast<ShapedType>(getOutput().getType());
3282  ArrayRef<int64_t> outputShape = outputType.getShape();
3283  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3284  return emitOpError("unexpected output shape");
3285  }
3286  return success();
3287 }
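
A worked instance of the rule above (assumed numbers): each alpha x alpha input tile yields an m x m output tile, so the output spatial extent is m times the tile count:

#include <cassert>
#include <cstdint>

void outputTransformShapeExample() {
  int64_t m = 4, r = 3;
  int64_t valueH = m + r - 1;       // 6, must equal the input tile size
  int64_t valueTileH = 3;           // 3 tiles along H
  int64_t outputH = m * valueTileH; // 12
  assert(valueH == 6 && outputH == 12);
}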
3288 
3289 SmallVector<Range>
3290 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3291  Location loc = getLoc();
3292  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3293  IntegerAttr oneAttr = builder.getIndexAttr(1);
3294  Value value = getValue();
3295  int64_t valueRank = getValueOperandRank();
3296  SmallVector<Range> loopBounds(valueRank);
3297  for (unsigned dim = 0; dim < valueRank; ++dim) {
3298  loopBounds[dim].offset = zeroAttr;
3299  // alphaH, alphaW, tileH, tileW, N, F
3300  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3301  loopBounds[dim].stride = oneAttr;
3302  }
3303  return loopBounds;
3304 }
3305 
3306 SmallVector<utils::IteratorType>
3307 WinogradOutputTransformOp::getLoopIteratorTypes() {
3308  int64_t valueRank = getValueOperandRank();
3309  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3310  utils::IteratorType::parallel);
3311  return iteratorTypes;
3312 }
3313 
3314 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3315  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3316  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3317  SmallVector<OpFoldResult> &resultSizes) {
3318  int64_t m = getM();
3319 
3320  Location loc = getLoc();
3321  MLIRContext *context = builder.getContext();
3322  auto identityAffineMap =
3323  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3324  auto affineMap =
3325  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3326 
3327  ShapedType valueType = getValueOperandType();
3328  ArrayRef<int64_t> valueShape = valueType.getShape();
3329  int64_t valueH = valueShape[0];
3330  int64_t valueW = valueShape[1];
3331  Value mappedOffsetH = affine::makeComposedAffineApply(
3332  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3333  offsets[getValueTileHDim()]);
3334  Value mappedOffsetW = affine::makeComposedAffineApply(
3335  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3336  offsets[getValueTileWDim()]);
3337  Value mappedSizeH = affine::makeComposedAffineApply(
3338  builder, loc, affineMap, sizes[getValueTileHDim()]);
3339  Value mappedSizeW = affine::makeComposedAffineApply(
3340  builder, loc, affineMap, sizes[getValueTileWDim()]);
3341 
3342  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3343  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3344  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3345  OpFoldResult sizeH =
3346  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3347  OpFoldResult sizeW =
3348  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3349 
3350  resultOffsets.append(
3351  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3352  resultSizes.append(
3353  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3354  return success();
3355 }
3356 
3357 /// Implements tiling for winograd_output_transform.
3358 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3359 /// F). The output of winograd_output_transform is (N, H, W, F). Users can
3360 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3361 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3362 /// for the sizes of tileH, tileW, N, F for one tile.
3363 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3364  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3365  ArrayRef<OpFoldResult> sizes) {
3366  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3367  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3368  Location loc = getLoc();
3369  SmallVector<Value> tiledOperands;
3370  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3371 
3372  ShapedType valueType = getValueOperandType();
3373  ArrayRef<int64_t> valueShape = valueType.getShape();
3374  int64_t alphaH = valueShape[getValueAlphaHDim()];
3375  int64_t alphaW = valueShape[getValueAlphaWDim()];
3376  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3377  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3378 
3379  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3380  offsets[getValueTileWDim()], offsets[getValueNDim()],
3381  offsets[getValueFDim()]});
3382  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3383  sizes[getValueTileWDim()], sizes[getValueNDim()],
3384  sizes[getValueFDim()]});
3385  int64_t valueRank = getValueOperandRank();
3386  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3387  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3388  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3389  tiledOperands.emplace_back(valueSlice);
3390 
3391  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3392  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3393  resultSizes)))
3394  return failure();
3395 
3396  int64_t outputRank = getOutputOperandRank();
3397  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3398  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3399  loc, getOutput(), resultOffsets, resultSizes, strides);
3400  tiledOperands.emplace_back(outputSlice);
3401 
3402  SmallVector<Type> resultTypes;
3403  resultTypes.push_back(tiledOperands[1].getType());
3404  Operation *tiledOp =
3405  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3406 
3407  return TilingResult{
3408  {tiledOp},
3409  SmallVector<Value>(tiledOp->getResults()),
3410  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3411 }
3412 
3413 //===----------------------------------------------------------------------===//
3414 // LinalgDialect
3415 //===----------------------------------------------------------------------===//
3416 
3417 void LinalgDialect::getCanonicalizationPatterns(
3418  RewritePatternSet &results) const {
3419  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3420  InferStaticShapeOfOperands>(getContext());
3421 }
3422 
3423 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
3424  Attribute value, Type type,
3425  Location loc) {
3426  return arith::ConstantOp::materialize(builder, value, type, loc);
3427 }
3428 
3429 // Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
3430 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3431  auto explicitRange = subMap.getResults();
3432  auto defaultRange = fullMap.getResults();
3433  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3434  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3435  llvm::set_union(explicitSet, defaultSet);
3436  return explicitSet == defaultSet;
3437 }
3438 
3439 /// Check if the user-defined map is a valid broadcast map. Broadcast indexing
3440 /// maps are defined relative to the corresponding default indexing maps for
3441 /// the given op, so the check reduces to comparing the number of result dims.
3442 /// Returns true if `explicitMap` is broadcast with respect to the
3443 /// defaultMap.
3445 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3446  return explicitMap.getNumResults() < defaultMap.getNumResults();
3447 }
3448 
3449 /// Verifies the broadcast and transpose semantics specified by the explicit
3450 /// indexing map for the MatmulOp \p op for the operand specified by \p
3451 /// opIndex.
3452 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3453  unsigned opIndex) {
3454  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3455  SmallVector<AffineMap, 3> defaultIndexingMaps =
3456  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3457 
3458  auto opIndexingMap = opIndexingMaps[opIndex];
3459  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3460  // Check general validity of indexing map results.
3461  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3462  return matmulOp->emitOpError()
3463  << "Unexpected dim expression in map result.";
3464 
3465  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3466  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3467  return matmulOp->emitOpError()
3468  << "Invalid broadcast requested, should be (d2).";
3469  }
3470  return success();
3471  }
3472  return success();
3473 }
3474 
3475 // Check general validity of input indexing map.
3476 static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
3477  AffineMap opIndexingMap,
3478  AffineMap defaultIndexingMap, bool isLHS) {
3479  // Check the result dims are valid.
3480  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3481  return batchMatmulOp->emitOpError()
3482  << "Unexpected result dim expression (outside the set of default "
3483  "result dims).";
3484 
3485  // Check for valid number of result dims of input maps.
3486  if (opIndexingMap.getNumResults() > 3)
3487  return batchMatmulOp->emitOpError()
3488  << "no. of result dim expressions exceeds 3.";
3489 
3490  auto hasValidBatchDim = [](AffineMap map) {
3491  AffineExpr batchDim = map.getResult(0);
3492  return batchDim.isFunctionOfDim(0);
3493  };
3494 
3495  // Check if the requested broadcast is valid.
3496  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3497  if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3498  return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3499  } else if (!hasValidBatchDim(opIndexingMap)) {
3500  return batchMatmulOp->emitOpError()
3501  << "Invalid batch dimension expression.";
3502  }
3503  return success();
3504 }
3505 
3506 /// This function checks if the given AffineMap for the output of a
3507 /// BatchMatmulOp has exactly 3 result dimensions and if the output map result
3508 /// dimensions are valid.
3509 static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3510  AffineMap opIndexingMap) {
3511  if (opIndexingMap.getNumResults() != 3)
3512  return batchMatmulOp->emitOpError()
3513  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3514  << ").";
3515 
3516  auto areValidOutputResultDim = [](AffineMap outputMap) {
3517  return outputMap.getResult(0).isFunctionOfDim(0) &&
3518  outputMap.getResult(1).isFunctionOfDim(1) &&
3519  outputMap.getResult(2).isFunctionOfDim(2);
3520  };
3521 
3522  if (!areValidOutputResultDim(opIndexingMap))
3523  return batchMatmulOp->emitOpError()
3524  << "Invalid output map result dimension.";
3525 
3526  return success();
3527 }
3528 
3529 /// Verifies the broadcast and transpose semantics specified by the explicit
3530 /// indexing map for the BatchMatmulOp op for the operand specified by opIndex.
3531 static LogicalResult
3532 verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
3533  unsigned opIndex) {
3534  SmallVector<AffineMap, 3> opIndexingMaps =
3535  batchMatmulOp.getIndexingMapsArray();
3536  SmallVector<AffineMap, 3> defaultIndexingMaps =
3537  batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3538 
3539  if (opIndexingMaps.size() != 3)
3540  return batchMatmulOp->emitOpError()
3541  << "Indexing_map attribute must have 3 affine maps.";
3542 
3543  auto opIndexingMap = opIndexingMaps[opIndex];
3544  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3545 
3546  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
3547  return failure();
3548 
3549  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3550  opIndex == 0)))
3551  return failure();
3552 
3553  return success();
3554 }
3555 
3556 namespace mlir {
3557 namespace linalg {
3558 
3559 //===----------------------------------------------------------------------===//
3560 // MatmulOp
3561 //===----------------------------------------------------------------------===//
3562 
3563 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3564 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3565  AffineExpr d0, d1, d2;
3566  SmallVector<AffineMap> indexingMaps;
3567  bindDims(context, d0, d1, d2);
3568  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3569  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3570  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3571  return indexingMaps;
3572 }
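
In affine_map notation the three defaults above read as follows (a reading aid only, restating the maps built by the function):

// maps[0]: affine_map<(d0, d1, d2) -> (d0, d2)>  // A, shape (M, K)
// maps[1]: affine_map<(d0, d1, d2) -> (d2, d1)>  // B, shape (K, N)
// maps[2]: affine_map<(d0, d1, d2) -> (d0, d1)>  // C, shape (M, N)
// i.e. C(d0, d1) += A(d0, d2) * B(d2, d1), with d2 the reduction dim.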
3573 
3574 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3575  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3576  utils::IteratorType::parallel,
3577  utils::IteratorType::reduction};
3578 }
3579 
3580 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3581 
3582 std::string MatmulOp::getLibraryCallName() {
3583  return generateLibraryCallName(getOperation());
3584 }
3585 
3586 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3587 
3588 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3589 /// the user-defined indexing maps are not equal to the default maps.
3590 bool MatmulOp::hasUserDefinedMaps() {
3591  SmallVector<AffineMap, 3> defaultMaps =
3592  getDefaultIndexingMaps(this->getContext());
3593  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3594  return defaultMaps != explicitMaps;
3595 }
3596 
3597 /// Implements the block region builder for the MatmulOp. This is called by
3598 /// 'fillStructuredOpRegion'.
3599 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3600  ArrayRef<NamedAttribute> attrs) {
3601  assert(block.getNumArguments() == 3 &&
3602  "MatmulOp regionBuilder expects 3 args");
3603  RegionBuilderHelper helper(b, block);
3604  SmallVector<Value> yields;
3605 
3606  TypeFn castVal = TypeFn::cast_signed;
3607  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3608  return attr.getName() == "cast";
3609  });
3610  if (castIter != attrs.end()) {
3611  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3612  castVal = attr.getValue();
3613  }
3614 
3615  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3616  block.getArgument(0));
3617  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3618  block.getArgument(1));
3619  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3620  Value value4 =
3621  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3622  yields.push_back(value4);
3623  helper.yieldOutputs(yields);
3624 }
3625 
3626 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3627 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3628  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3629  AffineExpr exp = bcastMap.getResult(0);
3630  // The map is invalid if the common (contraction) dimension of the matmul is not found.
3631  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3632 }
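
A sketch of the one accepted broadcast, built with the same AffineMap API this file already uses (the helper name is hypothetical; not part of the op):

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

// Builds affine_map<(d0, d1, d2) -> (d2)>: a rank-1 operand carrying only
// the shared contraction dim, which is the sole broadcast matmul accepts.
static mlir::AffineMap buildMatmulBroadcastMap(mlir::MLIRContext *context) {
  mlir::AffineExpr d0, d1, d2;
  mlir::bindDims(context, d0, d1, d2);
  return mlir::AffineMap::get(3, 0, {d2}, context);
}
// isValidLhsRhsBroadcastMap accepts it: getResult(0) is a function of
// dim 2, which equals getNumDims() - 1.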
3633 
3634 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3635  if (parser.parseOptionalKeyword("indexing_maps"))
3636  return ArrayAttr{
3637  nullptr}; // Success in case indexing_maps was not provided.
3638 
3639  ArrayAttr arrayAttr;
3640  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3641  return failure();
3642 
3643  if (llvm::any_of(arrayAttr,
3644  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3645  return parser.emitError(parser.getCurrentLocation())
3646  << "element of indexing_maps array is not an affine_map";
3647 
3648  return arrayAttr;
3649 }
3650 
3651 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3652  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3653  if (failed(indexingMapsAttr))
3654  return failure();
3655 
3656  if (*indexingMapsAttr == nullptr) {
3657  auto indexingMapAttrs = llvm::map_to_vector(
3658  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3659  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3660  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3661  }
3662 
3663  result.addAttribute("indexing_maps", *indexingMapsAttr);
3664  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3665  MatmulOp::getRegionBuilder());
3666 }
3667 
3668 void MatmulOp::print(OpAsmPrinter &p) {
3669  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3670  MatmulOp::getDefaultIndexingMaps(getContext()),
3671  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3672  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3673  p << " indexing_maps = [";
3674  llvm::interleaveComma(getIndexingMaps(), p,
3675  [&](Attribute attr) { p.printAttribute(attr); });
3676  p << "]";
3677  }
3678 
3679  SmallVector<StringRef, 3> elidedAttrs = {
3680  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3681  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3682  elidedAttrs);
3683 }
3684 
3685 /// Verify the user defined indexing maps.
3686 LogicalResult MatmulOp::verify() {
3687  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3688  if (!hasUserDefinedMaps())
3689  return success();
3690 
3691  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3692  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3693  return failure();
3694  }
3695  return success();
3696 }
3697 
3698 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3699  return memref::foldMemRefCast(*this);
3700 }
3701 
3702 void MatmulOp::getEffects(
3703  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3704  &effects) {
3705  if (hasPureTensorSemantics())
3706  return;
3707  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3708 }
3709 
3710 Speculation::Speculatability MatmulOp::getSpeculatability() {
3711  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3712 }
3713 
3714 //===----------------------------------------------------------------------===//
3715 // ContractOp
3716 //===----------------------------------------------------------------------===//
3717 
3718 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3719  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3720  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3721  // domains are all the same, and each implements a projected permutation.
3722  // Each iteration space dim must occur for at least one operand and either
3723  // takes part in a contraction/reduction or else has parallel iteration type.
3724  // We have that a dim is a contraction/reduction dim if and only if the dim
3725  // occurs for the output operand. We use this fact for fast inference:
3726  // NB: In case we allow dims to occur solely for one input, the above still
3727  // holds: per the einsum semantics, these are reduction dims as well.
3728  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3729  for (auto result : outAffineMap.getResults()) {
3730  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3731  assert(dimExpr && "affine_map is a projected permutation");
3732  dimsInOutput[dimExpr.getPosition()] = true;
3733  }
3734 
3735  SmallVector<utils::IteratorType> iteratorTypes;
3736  for (auto dimOccursInOutput : dimsInOutput)
3737  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3738  : utils::IteratorType::reduction);
3739 
3740  return iteratorTypes;
3741 }
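
A worked example of the inference above, using matmul-shaped maps (a reading aid only):

// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,   // LHS
//                  affine_map<(d0, d1, d2) -> (d2, d1)>,   // RHS
//                  affine_map<(d0, d1, d2) -> (d0, d1)>]   // output
// d0, d1 occur in the output map  -> parallel
// d2 does not occur in the output -> reduction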
3742 
3743 unsigned ContractOp::getNumRegionArgs() { return 3; }
3744 
3745 /// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
3746 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3747  ArrayRef<NamedAttribute> attrs) {
3748  assert(block.getNumArguments() == 3 &&
3749  "ContractOp regionBuilder expects 3 args");
3750  RegionBuilderHelper helper(b, block);
3751 
3752  TypeFn castSignedness = TypeFn::cast_signed;
3753  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3754  return attr.getName() == "cast";
3755  });
3756  if (castIter != attrs.end()) {
3757  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3758  castSignedness = attr.getValue();
3759  }
3760 
3761  // TODO: Support fields with operators besides mult & add.
3762  Type outType = block.getArgument(2).getType();
3763  Value lhsAtOutType =
3764  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3765  Value rhsAtOutType =
3766  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3767  Value productAtOutType =
3768  helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3769  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3770  productAtOutType);
3771  helper.yieldOutputs({result});
3772 }
3773 
3774 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3775  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3776  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3777  return parser.emitError(parser.getCurrentLocation(),
3778  "expected 'indexing_maps' attribute");
3779  result.addAttribute("indexing_maps", *indexingMapsAttr);
3780 
3781  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3782  regionBuilder);
3783 }
3784 
3786  p << " indexing_maps = [";
3787  llvm::interleaveComma(getIndexingMaps(), p,
3788  [&](Attribute attr) { p.printAttribute(attr); });
3789  p << "]";
3790  printNamedStructuredOp(
3791  p, getOperation(), getInputs(), getOutputs(),
3792  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3793 }
3794 
3795 LogicalResult ContractOp::verify() {
3796  int iterationSpaceDims = -1;
3797  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3798  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3799  // access an input operand (so occurrence count can be at most 2) and
3800  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3801  SmallVector<size_t> inOccurrences;
3802  SmallVector<size_t> outOccurrences;
3803 
3804  // A helper so that for each operand's affine_map and type we check that ...
3805  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3806  bool isInput) -> LogicalResult {
3807  // ... the affine_map is a projected permutation;
3808  if (!affineMap.isProjectedPermutation())
3809  return emitError("provided affine_map is not a projected permutation");
3810 
3811  // ... the rank of the affine_map's results and corresponding type match;
3812  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3813  if (affineMap.getNumResults() != shapedType.getRank())
3814  return emitError("ranks of shaped operand and results of corresponding "
3815  "affine_map differ");
3816  } else if (affineMap.getNumResults() != 0) {
3817  return emitError("affine_map specifies shaped access while operand has "
3818  "non-shaped type");
3819  }
3820 
3821  // ... the rank of the affine_map's domain is the same as those seen prior;
3822  if (iterationSpaceDims == -1) {
3823  iterationSpaceDims = affineMap.getNumDims();
3824  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3825  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3826  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3827  return emitError("iteration spaces of provided affine_maps differ");
3828  }
3829 
3830  // ... update counts of dims used to access either an input or the output.
3831  for (AffineExpr affineExpr : affineMap.getResults()) {
3832  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3833  if (!affineDimExpr)
3834  llvm_unreachable("affine_map is a projected permutation");
3835 
3836  if (isInput)
3837  inOccurrences[affineDimExpr.getPosition()] += 1;
3838  else
3839  outOccurrences[affineDimExpr.getPosition()] += 1;
3840  }
3841 
3842  return success();
3843  };
3844 
3845  for (auto &&[affineMap, operandType, isInput] :
3846  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3847  SmallVector<bool>{true, true, false})) {
3848  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3849  return failure(); // NB: checkAffineMapAndType will emit relevant error.
3850  }
3851 
3852  bool hasContractingDim = false;
3853  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3854  size_t inOccCount = inOccurrences[dimIndex];
3855  size_t outOccCount = outOccurrences[dimIndex];
3856 
3857  // We have a contracting dim if and only if ...
3858  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3859 
3860  if (inOccCount == 0 && outOccCount == 0)
3861  return emitError() << "iteration space dim at index " << dimIndex
3862  << " not used to access any operand";
3863 
3864  // NB: We disallow a dim which occurs for only one input operand and not
3865  // for the output. In terms of einsum semantics such dims have a
3866  // sensible meaning - namely an additional reduction per each such dim.
3867  // By contrast, the ContractionOpInterface does not know about this
3868  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3869  // while vector.contract's verifier accepts dims of this kind many of
3870  // its lowerings give up on encountering these dims.
3871  // TODO: Remove following once we have comprehensive support for input-only
3872  // reduction dims, at both the linalg- and vector-dialect levels.
3873  if (inOccCount == 1 && outOccCount != 1)
3874  return emitError()
3875  << "iteration space dim at index " << dimIndex
3876  << " is neither a contracting dim nor of parallel iteration type";
3877  }
3878 
3879  if (!hasContractingDim)
3880  return emitError("'indexing_maps' do not specify a contracting dimension");
3881 
3882  return success();
3883 }
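
The occurrence counting above, traced on the same matmul-shaped maps (a reading aid only):

// A: (d0, d2), B: (d2, d1), C: (d0, d1)
// inOccurrences  = [1, 1, 2]   // d2 is read by both inputs
// outOccurrences = [1, 1, 0]   // d2 is absent from the output
// d2: inOcc == 2 && outOcc == 0 -> the required contracting dim.
// d0, d1: inOcc == 1 && outOcc == 1 -> parallel dims; verification succeeds.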
3884 
3885 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3886  return memref::foldMemRefCast(*this);
3887 }
3888 
3889 void ContractOp::getEffects(
3890  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3891  &effects) {
3892  if (hasPureTensorSemantics())
3893  return;
3894  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3895 }
3896 
3897 Speculation::Speculatability ContractOp::getSpeculatability() {
3898  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3899 }
3900 
3901 //===----------------------------------------------------------------------===//
3902 // BatchMatmulOp
3903 //===----------------------------------------------------------------------===//
3904 SmallVector<AffineMap>
3905 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3906  AffineExpr d0, d1, d2, d3;
3907  SmallVector<AffineMap> indexingMaps;
3908  bindDims(context, d0, d1, d2, d3);
3909  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3910  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3911  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3912  return indexingMaps;
3913 }
3914 
3915 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3916  return SmallVector<utils::IteratorType>{
3917  utils::IteratorType::parallel, utils::IteratorType::parallel,
3918  utils::IteratorType::parallel, utils::IteratorType::reduction};
3919 }
3920 
3921 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3922 
3923 std::string BatchMatmulOp::getLibraryCallName() {
3924  return generateLibraryCallName(getOperation());
3925 }
3926 
3927 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3928 /// the user-defined indexing maps are not equal to the default maps.
3929 bool BatchMatmulOp::hasUserDefinedMaps() {
3930  SmallVector<AffineMap, 3> defaultMaps =
3931  getDefaultIndexingMaps(this->getContext());
3932  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3933  return defaultMaps != explicitMaps;
3934 }
3935 
3936 /// Returns true if the given broadcast map bcastMap is valid for this op.
3937 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
3938  assert(bcastMap.getNumResults() < 3 &&
3939  "Expected less than 3 result dim expr.");
3940  bool isValid = false;
3941  enum Indices { batchPos, mPos, nPos, kPos };
3942  if (bcastMap.getNumResults() == 1) {
3943  AffineExpr exp = bcastMap.getResult(0);
3944  isValid = exp.isFunctionOfDim(kPos);
3945  } else if (bcastMap.getNumResults() == 2) {
3946  AffineExpr exp0 = bcastMap.getResult(0);
3947  AffineExpr exp1 = bcastMap.getResult(1);
3948  isValid = isLHS
3949  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
3950  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
3951  }
3952  return isValid;
3953 }
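
Worked examples against the rank-4 iteration space above, with (batchPos, mPos, nPos, kPos) = (d0, d1, d2, d3) (a reading aid only):

// (d0, d1, d2, d3) -> (d3)      valid for either side (contraction dim only)
// (d0, d1, d2, d3) -> (d1, d3)  valid LHS broadcast: (m, k)
// (d0, d1, d2, d3) -> (d3, d2)  valid RHS broadcast: (k, n)
// (d0, d1, d2, d3) -> (d2, d3)  invalid as LHS: first expr is not m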
3954 
3955 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3956  ArrayRef<NamedAttribute> attrs) {
3957  assert(block.getNumArguments() == 3 &&
3958  "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3959  RegionBuilderHelper helper(b, block);
3960  SmallVector<Value> yields;
3961 
3962  auto toType = block.getArgument(2).getType();
3963  Value castValA =
3964  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
3965  Value castValB =
3966  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
3967  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3968  Value addVal =
3969  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
3970  yields.push_back(addVal);
3971  helper.yieldOutputs(yields);
3972 }
3973 
3974 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3975  SmallVector<Attribute, 3> indexingMapsAttr;
3976  Attribute mapAttr;
3977  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
3978  if (parser.parseEqual())
3979  return failure();
3980 
3981  if (parser.parseLSquare())
3982  return failure();
3983 
3984  do {
3985  if (parser.parseAttribute(mapAttr))
3986  return failure();
3987  if (!isa<AffineMapAttr>(mapAttr)) {
3988  return parser.emitError(parser.getCurrentLocation(),
3989  "expected affine map attribute");
3990  }
3991  indexingMapsAttr.push_back(mapAttr);
3992 
3993  if (parser.parseOptionalComma())
3994  break;
3995  } while (true);
3996 
3997  if (parser.parseRSquare())
3998  return failure();
3999  }
4000  // Initialize indexingMaps, if not supplied explicitly.
4001  if (indexingMapsAttr.empty()) {
4002  indexingMapsAttr = llvm::map_to_vector(
4003  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4004  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4005  }
4006  result.addAttribute("indexing_maps",
4007  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4008 
4009  return ::parseNamedStructuredOp(parser, result,
4010  BatchMatmulOp::getNumRegionArgs(),
4011  BatchMatmulOp::getRegionBuilder());
4012 }
4013 
4014 void BatchMatmulOp::print(OpAsmPrinter &p) {
4015  SmallVector<StringRef, 3> elidedAttrs = {
4016  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4017  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4018  elidedAttrs);
4019 
4020  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
4021  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4022  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4023  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4024  p << " indexing_maps = [";
4025  llvm::interleaveComma(getIndexingMaps(), p,
4026  [&](Attribute attr) { p.printAttribute(attr); });
4027  p << "]";
4028  }
4029 }
4030 
4031 /// Verify the user defined indexing maps.
4032 LogicalResult BatchMatmulOp::verify() {
4033  // Verification of pure batch_matmul is handled by
4034  // verifyStructuredOpInterface().
4035  if (!hasUserDefinedMaps())
4036  return success();
4037 
4038  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4039  if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
4040  return failure();
4041  }
4042  return success();
4043 }
4044 
4045 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4046  SmallVectorImpl<OpFoldResult> &) {
4047  return memref::foldMemRefCast(*this);
4048 }
4049 
4050 void BatchMatmulOp::getEffects(
4051  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4052  &effects) {
4053  if (hasPureTensorSemantics())
4054  return;
4055  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4056 }
4057 
4058 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4059  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4060 }
4061 
4062 } // namespace linalg
4063 } // namespace mlir
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3452
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1851
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:309
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2782
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2853
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
Definition: LinalgOps.cpp:3430
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:125
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2321
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
Definition: LinalgOps.cpp:3445
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:297
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
Definition: LinalgOps.cpp:3509
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1684
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1732
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2827
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:187
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3532
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1484
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:159
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1353
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:203
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2804
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1244
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1211
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2234
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:335
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:986
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:328
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:58
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1407
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1513
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:365
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
Definition: LinalgOps.cpp:3476
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:372
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:223
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
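A short sketch of the Block API above; `region`, `b` (a Builder), and `loc` are assumed to exist:

  Block &block = region.emplaceBlock();
  BlockArgument arg = block.addArgument(b.getIndexType(), loc);
  assert(block.getNumArguments() == 1 && block.getArgument(0) == arg);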
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
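A sketch of creating common attributes with these Builder helpers, assuming an MLIRContext *ctx:

  Builder b(ctx);
  IntegerAttr c42 = b.getIndexAttr(42);
  DenseI64ArrayAttr perm = b.getDenseI64ArrayAttr({0, 2, 1});
  ArrayAttr maps = b.getAffineMapArrayAttr({b.getMultiDimIdentityMap(2)});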
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work to remain sorted.
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with a new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
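A sketch of the NamedAttrList API above; `b` (a Builder) and `ctx` are assumed, and "library_call" is just an example attribute name:

  NamedAttrList attrs;
  attrs.append("library_call", b.getStringAttr("foo"));
  // set() overwrites the existing entry and returns the previous value.
  attrs.set(b.getStringAttr("library_call"), b.getStringAttr("bar"));
  DictionaryAttr dict = attrs.getDictionary(ctx);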
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:222
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
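A sketch of driving an OpBuilder: create a block, move the insertion point, and build an op. `region`, `loc`, and `ctx` are assumed; the guard restores the insertion point on scope exit:

  OpBuilder b(ctx);
  OpBuilder::InsertionGuard guard(b);
  Block *block = b.createBlock(&region, region.end());
  b.setInsertionPointToStart(block);
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);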
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:114
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:96
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1148
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1198
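A hedged sketch of makeComposedFoldedAffineApply computing a + b over OpFoldResults; `b` (an OpBuilder), `loc`, `ofrA`, and `ofrB` are assumed, and the affine:: namespace qualification reflects recent MLIR:

  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  // Folds to an Attribute when both inputs are constants.
  OpFoldResult sum = affine::makeComposedFoldedAffineApply(
      b, loc, AffineMap::get(2, 0, d0 + d1), {ofrA, ofrB});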
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2630
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2313
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:105
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name used to disambiguate between different overloads at the C level.
Definition: LinalgOps.cpp:2354
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2293
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2304
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:96
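A sketch of the dim helpers defined in this file; `b`, `loc`, and `val` (a tensor or memref value) are assumed:

  // Emits memref.dim/tensor.dim, folding away static dimensions.
  Value d0 = linalg::createOrFoldDimOp(b, loc, val, /*dim=*/0);
  // The folded variant yields an IntegerAttr for static dims.
  OpFoldResult d1 = linalg::createFoldedDimOp(b, loc, val, /*dim=*/1);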
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3634
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:350
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:68
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExprs at positions [0 .. N), where N is the number of expressions.
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
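A sketch of inversePermutation on a full permutation map, assuming an MLIRContext *ctx:

  // (d0, d1, d2) -> (d2, d0, d1).
  AffineMap p = AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
  // Inverse: (d0, d1, d2) -> (d1, d2, d0).
  AffineMap inv = inversePermutation(p);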
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExprs at positions [0 .. N), where N is the number of expressions.
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction code.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
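A sketch of round-tripping between OpFoldResult and Value; `b`, `loc`, and `ofr` are assumed:

  // Materializes an arith.constant only if `ofr` holds an Attribute.
  Value v = getValueOrCreateConstantIndexOp(b, loc, ofr);
  // Recovers the constant Attribute form where possible.
  OpFoldResult back = getAsOpFoldResult(v);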
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the detail classes directly.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
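A sketch of the permutation-vector utilities on plain index vectors:

  SmallVector<int64_t> perm = {2, 0, 1};
  assert(isPermutationVector(perm));
  // Inverse of {2, 0, 1} is {1, 2, 0}.
  SmallVector<int64_t> inv = invertPermutationVector(perm);
  // applyPermutation picks input[perm[i]]: {4, 8, 16} -> {16, 4, 8}.
  SmallVector<int64_t> shape =
      applyPermutation<int64_t>(ArrayRef<int64_t>{4, 8, 16}, perm);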
Fold transpose with transpose.
Definition: LinalgOps.cpp:1994
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1997
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
Definition: LinalgOps.cpp:2019
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2022
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
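A hedged sketch of an OpRewritePattern in the shape of the transpose-with-transpose fold above; the pattern name is hypothetical and the TransposeOp builder overload taking an ArrayRef<int64_t> permutation is assumed:

  struct FoldDoubleTranspose : OpRewritePattern<linalg::TransposeOp> {
    using OpRewritePattern::OpRewritePattern;
    LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                  PatternRewriter &rewriter) const override {
      auto producer = op.getInput().getDefiningOp<linalg::TransposeOp>();
      if (!producer)
        return rewriter.notifyMatchFailure(op, "input is not a transpose");
      // transpose(transpose(x, p1), p2) == transpose(x, composed) with
      // composed[i] = p1[p2[i]].
      SmallVector<int64_t> composed =
          applyPermutation(producer.getPermutation(), op.getPermutation());
      rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
          op, producer.getInput(), op.getInit(), composed);
      return success();
    }
  };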
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.