//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
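/// Example (illustrative; operand names and shapes are made up):
///   %0 = tensor.extract_slice %source[2] [4] [1]
///          : tensor<8xf32> to tensor<4xf32>
/// For a memref-typed `source`, the analogous memref.subview is created
/// instead; any other type yields nullptr.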
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
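  // The default maps are the canonical matmul projections (illustrative):
  //   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)
  // for the two inputs and the output, respectively.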
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
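/// The operand lists parsed here look like (illustrative):
///   ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
///   outs(%c : tensor<?x?xf32>)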
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible with
    // operation syntax that mixes the inherent attributes and the discardable
    // ones in the same dictionary. If the properties are used, we append the
    // operandSegmentSizes there directly. Otherwise we append it to the
    // discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

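/// Erase `linalg.copy` ops whose input and output are the same value; such a
/// copy is a no-op. Example (illustrative):
///   linalg.copy ins(%buf : memref<4xf32>) outs(%buf : memref<4xf32>)
/// is erased in the buffer case, and replaced by its input otherwise.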
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
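/// Example (illustrative; shapes are made up):
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4x4xf32>) -> tensor<4x4xf32>
///   %0 = tensor.collapse_shape %fill [[0, 1]]
///          : tensor<4x4xf32> into tensor<16xf32>
/// becomes
///   %1 = tensor.collapse_shape %init [[0, 1]]
///          : tensor<4x4xf32> into tensor<16xf32>
///   %0 = linalg.fill ins(%cst : f32)
///          outs(%1 : tensor<16xf32>) -> tensor<16xf32>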
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
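/// Example (illustrative; shapes and names are made up):
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %pad = tensor.pad %fill low[1] high[1] {
///          ^bb0(%i: index):
///            tensor.yield %cst : f32
///          } : tensor<4xf32> to tensor<6xf32>
/// becomes a linalg.fill of %cst into a tensor.empty of the padded shape.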
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
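/// Example (illustrative): with %dest produced by a linalg.fill of %cst and
///   %pad = tensor.pad %src low[1] high[1] (yielding %cst),
///   %r = tensor.insert_slice %pad into %dest[2] [6] [1] ...
/// the pad can be dropped:
///   %r = tensor.insert_slice %src into %dest[3] [4] [1] ...
/// i.e. the offset is shifted by the low padding and the size shrinks to the
/// unpadded source size.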
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
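/// Example (illustrative):
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %e = tensor.extract %fill[%i] : tensor<4xf32>
/// folds %e to %cst regardless of the index.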
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
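/// Example (illustrative; tile sizes are made up):
///   %fill = linalg.fill ins(%cst : f32)
///             outs(%src : tensor<16x16xf32>) -> tensor<16x16xf32>
///   %pack = linalg.pack %fill inner_dims_pos = [0, 1] inner_tiles = [8, 8]
///             into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
/// becomes linalg.fill ins(%cst) outs(%dest).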
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
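/// Sketch (illustrative pseudo-notation):
///   linalg.copy(linalg.fill(%cst, %a), %b) --> linalg.fill(%cst, %b)
///   linalg.copy(%in, linalg.fill(%cst, %b)) --> linalg.copy(%in, %b)
/// i.e. filling the copy destination directly, or copying over a value that
/// would be overwritten anyway.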
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
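/// Sketch (illustrative pseudo-notation): transposing a filled tensor
/// produces the same elements as filling the transpose's init directly:
///   linalg.transpose(linalg.fill(%cst, %init)) --> linalg.fill(%cst, %init2)
/// where %init2 is the transpose's own init operand.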
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
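/// Example (illustrative):
///   %a = linalg.fill ins(%cst : f32) outs(%e0 : tensor<4xf32>) -> tensor<4xf32>
///   %b = linalg.fill ins(%cst : f32) outs(%e1 : tensor<8xf32>) -> tensor<8xf32>
///   %c = tensor.concat dim(0) %a, %b
///          : (tensor<4xf32>, tensor<8xf32>) -> tensor<12xf32>
/// becomes a single linalg.fill of %cst into the concat of %e0 and %e1.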
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//
1005 
1007  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1008  ValueRange outputs,
1009  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1010  SmallVector<Type, 4> blockArgTypes;
1011  SmallVector<Location, 4> blockArgLocs;
1012  for (ValueRange container : {inputs, outputs}) {
1013  for (Value v : container) {
1014  Type t = v.getType();
1015  blockArgTypes.push_back(
1016  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
1017  blockArgLocs.push_back(v.getLoc());
1018  }
1019  }
1020 
1021  OpBuilder::InsertionGuard guard(builder);
1022  Block *bodyBlock =
1023  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1024  bodyBuild(builder, loc, bodyBlock->getArguments());
1025 }

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
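/// Example (illustrative): with #id = affine_map<(d0) -> (d0)>,
///   %r = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///        ins(%a : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<?xf32>
/// is replaced by %a.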
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

1472 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1473  std::optional<OperationName> payloadOpName;
1474  NamedAttrList payloadOpAttrs;
1475  if (succeeded(parser.parseOptionalLBrace())) {
1476  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1477  if (failed(operationName))
1478  return failure();
1479  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1480  return failure();
1481  payloadOpName = operationName.value();
1482  if (parser.parseRBrace())
1483  return failure();
1484  }
1485 
1486  if (parseDstStyleOp(parser, result))
1487  return failure();
1488 
1489  if (payloadOpName.has_value()) {
1490  if (!result.operands.empty())
1491  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1492  payloadOpAttrs,
1493  ArrayRef(result.operands).drop_back());
1494  else
1495  result.addRegion();
1496  } else {
1498  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1499  /*allowType=*/true, /*allowAttrs=*/true)) {
1500  return failure();
1501  }
1502  Region *body = result.addRegion();
1503  if (parser.parseRegion(*body, regionArgs))
1504  return failure();
1505  }
1506  return success();
1507 }
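// For reference, a sketch of the two concrete-syntax forms this parser
// accepts (hypothetical IR, not taken from an actual test):
//
//   // Short form: the payload op is named in braces and the region is
//   // reconstructed by addBodyWithPayloadOp.
//   %mapped = linalg.map { arith.addf }
//       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//       outs(%init : tensor<8xf32>)
//
//   // Long form: the mapper region is spelled out explicitly.
//   %mapped = linalg.map
//       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//       outs(%init : tensor<8xf32>)
//       (%l : f32, %r : f32) {
//         %0 = arith.addf %l, %r : f32
//         linalg.yield %0 : f32
//       }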
1508 
1509 // Retrieve the operation from the body, if it is the only one (besides the
1510 // yield) and if it takes the same number of arguments as the body does.
1511 // If the initFirst flag is enabled, we check that init takes the first
1512 // position in the operands of the payload.
1513 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1514  if (body->getOperations().size() != 2)
1515  return nullptr;
1516  Operation &payload = body->getOperations().front();
1517  assert(isa<YieldOp>(body->getOperations().back()));
1518 
1519  if (payload.getNumOperands() == 0 ||
1520  payload.getNumOperands() != body->getNumArguments())
1521  return nullptr;
1522  if (initFirst) {
1523  // check init
1524  if (payload.getOperands().back() != body->getArgument(0))
1525  return nullptr;
1526  // check rest
1527  for (const auto &[operand, bbArg] :
1528  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1529  if (bbArg != operand)
1530  return nullptr;
1531  }
1532  } else {
1533  for (const auto &[operand, bbArg] :
1534  llvm::zip(payload.getOperands(), body->getArguments())) {
1535  if (bbArg != operand)
1536  return nullptr;
1537  }
1538  }
1539  return &payload;
1540 }
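// For example (a sketch): a mapper body such as
//   (%in0 : f32, %in1 : f32) {
//     %0 = arith.addf %in0, %in1 : f32
//     linalg.yield %0 : f32
//   }
// holds exactly one op besides the yield, and that op consumes the block
// arguments in order, so it is returned as the payload and the enclosing op
// can be printed in short form. A payload that reuses or reorders arguments,
// e.g. arith.addf %in0, %in0, does not match and forces the long form.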
1541 
1542 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1543  SmallVector<StringRef> elidedAttrs;
1544  std::string attrToElide;
1545  p << " { " << payloadOp->getName().getStringRef();
1546  for (const auto &attr : payloadOp->getAttrs()) {
1547  auto fastAttr =
1548  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1549  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1550  attrToElide = attr.getName().str();
1551  elidedAttrs.push_back(attrToElide);
1552  break;
1553  }
1554  }
1555  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1556  p << " }";
1557 }
1558 
1559 void MapOp::print(OpAsmPrinter &p) {
1560  Block *mapper = getBody();
1561  Operation *payloadOp = findPayloadOp(mapper);
1562  if (payloadOp) {
1563  printShortForm(p, payloadOp);
1564  }
1565 
1566  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1567  p.printOptionalAttrDict((*this)->getAttrs());
1568 
1569  if (!payloadOp) {
1570  // Print region if the payload op was not detected.
1571  p.increaseIndent();
1572  p.printNewline();
1573  p << "(";
1574  llvm::interleaveComma(mapper->getArguments(), p,
1575  [&](auto arg) { p.printRegionArgument(arg); });
1576  p << ") ";
1577 
1578  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1579  p.decreaseIndent();
1580  }
1581 }
1582 
1583 LogicalResult MapOp::verify() {
1584  auto *bodyBlock = getBody();
1585  auto blockArgs = bodyBlock->getArguments();
1586 
1587  // Check that the number of `inputs` matches the arity of the `mapper` region.
1588  if (getInputs().size() != blockArgs.size())
1589  return emitOpError() << "expects number of operands to match the arity of "
1590  "mapper, but got: "
1591  << getInputs().size() << " and " << blockArgs.size();
1592 
1593  // The parameters of the mapper should all match the element types of the inputs.
1594  for (const auto &[bbArgType, inputArg] :
1595  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1596  auto inputElemType =
1597  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1598  if (bbArgType != inputElemType) {
1599  return emitOpError() << "expected element type of input " << inputElemType
1600  << " to match bbArg type " << bbArgType;
1601  }
1602  }
1603 
1604  // The shape of each input must match the shape of the output.
1605  auto outputShape = getInit().getType().getShape();
1606  for (Type inputArgType : TypeRange{getInputs()}) {
1607  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1608  if (inputElemShape != outputShape) {
1609  return emitOpError() << "expected shape of input (" << inputElemShape
1610  << ") to match shape of output (" << outputShape
1611  << ")";
1612  }
1613  }
1614 
1615  return success();
1616 }
1617 
1618 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1619  int64_t rank = getInit().getType().getRank();
1620  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1621 }
1622 
1623 ArrayAttr MapOp::getIndexingMaps() {
1624  Builder builder(getContext());
1625  int64_t rank = getInit().getType().getRank();
1626  int64_t numIndexingMaps = getOperands().size();
1627  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1628  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1629 }
1630 
1631 void MapOp::getEffects(
1632  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1633  &effects) {
1634  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1635 }
1636 
1637 Speculation::Speculatability MapOp::getSpeculatability() {
1638  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1639 }
1640 
1641 //===----------------------------------------------------------------------===//
1642 // ReduceOp
1643 //===----------------------------------------------------------------------===//
1644 
1645 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1646  OpAsmSetValueNameFn setNameFn) {
1647  for (Value v : getRegionInputArgs())
1648  setNameFn(v, "in");
1649  for (Value v : getRegionOutputArgs())
1650  setNameFn(v, "init");
1651 }
1652 
1653 void ReduceOp::getAsmResultNames(
1654  function_ref<void(Value, StringRef)> setNameFn) {
1655  if (!getResults().empty())
1656  setNameFn(getResults().front(), "reduced");
1657 }
1658 
1659 void ReduceOp::build(
1660  OpBuilder &builder, OperationState &result, ValueRange inputs,
1661  ValueRange inits, ArrayRef<int64_t> dimensions,
1662  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1663  ArrayRef<NamedAttribute> attributes) {
1664  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1665  result.addAttributes(attributes);
1666 
1667  // Add output types for `RankedTensorType` output arguments.
1668  for (Value init : inits) {
1669  Type initType = init.getType();
1670  if (llvm::isa<RankedTensorType>(initType))
1671  result.addTypes(initType);
1672  }
1673 
1674  if (bodyBuild)
1675  buildGenericRegion(builder, result.location, *result.regions.front(),
1676  inputs, inits, bodyBuild);
1677 }
1678 
1679 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1680  int64_t inputRank =
1681  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1682  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1683  utils::IteratorType::parallel);
1684  for (int64_t reductionDim : getDimensions())
1685  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1686  return iteratorTypes;
1687 }
1688 
1689 ArrayAttr ReduceOp::getIndexingMaps() {
1690  int64_t inputRank =
1691  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1692  SmallVector<AffineMap> affineMaps(
1693  getNumDpsInputs(),
1694  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1695  AffineMap resultMap =
1696  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1697  .dropResults(getDimensions());
1698  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1699  affineMaps.push_back(resultMap);
1700  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1701 }
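// As a concrete illustration (not from the source): for a rank-3 input
// reduced over dimensions = [1], every input keeps the identity map
// (d0, d1, d2) -> (d0, d1, d2), while every init gets the identity map with
// the reduced result dropped, i.e. (d0, d1, d2) -> (d0, d2).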
1702 
1703 void ReduceOp::getEffects(
1704  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1705  &effects) {
1706  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1707 }
1708 
1709 Speculation::Speculatability ReduceOp::getSpeculatability() {
1710  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1711 }
1712 
1713 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1714  NamedAttrList &attributes,
1715  StringRef attributeName) {
1716  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1717  return failure();
1718 
1719  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1720  return success();
1721 }
1722 
1723 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1724  std::optional<OperationName> payloadOpName;
1725  NamedAttrList payloadOpAttrs;
1726  if (succeeded(parser.parseOptionalLBrace())) {
1727  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1728  if (failed(operationName))
1729  return failure();
1730  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1731  return failure();
1732  payloadOpName = operationName.value();
1733  if (parser.parseRBrace())
1734  return failure();
1735  }
1736 
1737  if (parseDstStyleOp(
1738  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1739  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1740  }))
1741  return failure();
1742 
1743  if (payloadOpName.has_value()) {
1744  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1745  ArrayRef(result.operands), /*initFirst=*/true);
1746  } else {
1747  SmallVector<OpAsmParser::Argument> regionArgs;
1748  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1749  /*allowType=*/true, /*allowAttrs=*/true)) {
1750  return failure();
1751  }
1752 
1753  Region *body = result.addRegion();
1754  if (parser.parseRegion(*body, regionArgs))
1755  return failure();
1756  }
1757 
1758  return success();
1759 }
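// A sketch of the two accepted forms on hypothetical IR. The short form
//   %reduced = linalg.reduce { arith.addf } ins(%in : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>) dimensions = [1]
// is equivalent to the long form with an explicit combiner region, where
// init is passed as the first payload operand (initFirst = true):
//   %reduced = linalg.reduce ins(%in : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>) dimensions = [1]
//       (%in_elem : f32, %init_elem : f32) {
//         %0 = arith.addf %init_elem, %in_elem : f32
//         linalg.yield %0 : f32
//       }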
1760 
1761 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1762  ArrayRef<int64_t> attributeValue) {
1763  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1764 }
1765 
1766 void ReduceOp::print(OpAsmPrinter &p) {
1767  Block *mapper = getBody();
1768  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1769  if (payloadOp) {
1770  printShortForm(p, payloadOp);
1771  }
1772 
1773  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1774  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1775  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1776  if (!payloadOp) {
1777  // Print region if the payload op was not detected.
1778  p.increaseIndent();
1779  p.printNewline();
1780  p << "(";
1781  llvm::interleaveComma(mapper->getArguments(), p,
1782  [&](auto arg) { p.printRegionArgument(arg); });
1783  p << ") ";
1784 
1785  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1786  p.decreaseIndent();
1787  }
1788 }
1789 
1790 LogicalResult ReduceOp::verify() {
1791  ArrayRef<int64_t> dimensionsRef = getDimensions();
1792 
1793  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1794  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1795  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1796  return emitOpError() << "expects all inputs to have the same shapes. "
1797  "Shape at input-index "
1798  << i
1799  << " is not equal to the shape at input-index 0.";
1800  }
1801  }
1802  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1803  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1804  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1805  return emitOpError() << "expects all outputs to have the same shapes. "
1806  "Shape at output-index "
1807  << i
1808  << " is not equal to the shape at output-index 0.";
1809  }
1810  }
1811  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1812  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1813 
1814  DenseSet<int64_t> dimensionsToReduce;
1815  for (int64_t dimension : dimensionsRef) {
1816  if (dimension < 0 || dimension >= inputType.getRank()) {
1817  return emitOpError()
1818  << "dimensions for reduction should be in the range [0, "
1819  << inputType.getRank() - 1 << "].";
1820  }
1821  dimensionsToReduce.insert(dimension);
1822  }
1823 
1824  auto inputDims = inputType.getShape();
1825  auto initDims = initType.getShape();
1826 
1827  // Input dimensions that will be left after the reduction.
1828  SmallVector<int64_t> reducedInputDims;
1829  for (const auto &en : llvm::enumerate(inputDims)) {
1830  if (!dimensionsToReduce.count(en.index()))
1831  reducedInputDims.push_back(en.value());
1832  }
1833 
1834  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1835  return emitOpError() << "number of dimensions after reduction "
1836  << reducedInputDims.size()
1837  << " doesn't match the init rank "
1838  << initType.getRank();
1839  }
1840 
1841  if (reducedInputDims != initDims)
1842  return emitOpError() << "init dimensions [" << initDims
1843  << "] doesn't match input dimensions after reduction ["
1844  << reducedInputDims << "]";
1845 
1846  Block *block = getBody();
1847  if (block->getNumArguments() != this->getNumOperands())
1848  return emitOpError()
1849  << "mismatching number of operands and block arguments";
1850 
1851  // Check that the first block arguments match the element type of the inputs.
1852  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1853  Type inputElementType =
1854  llvm::cast<ShapedType>(input.getType()).getElementType();
1855  if (inputElementType != bbArg.getType())
1856  return emitOpError()
1857  << "input element type " << inputElementType
1858  << " does not match corresponding block argument type "
1859  << bbArg.getType();
1860  }
1861 
1862  // Check that the last block arguments match the element type of the outputs.
1863  for (auto [output, bbArg] : llvm::zip(
1864  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1865  auto outputElementType =
1866  llvm::cast<ShapedType>(output.getType()).getElementType();
1867  if (outputElementType != bbArg.getType())
1868  return emitOpError()
1869  << "output element type " << outputElementType
1870  << " does not match corresponding block argument type "
1871  << bbArg.getType();
1872  }
1873  return success();
1874 }
1875 
1876 //===----------------------------------------------------------------------===//
1877 // TransposeOp
1878 //===----------------------------------------------------------------------===//
1879 
1880 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1881  Region &region, ValueRange inputs,
1882  ValueRange outputs) {
1883  buildGenericRegion(builder, loc, region, inputs, outputs,
1884  [](OpBuilder &b, Location loc, ValueRange args) {
1885  if (!args.empty())
1886  b.create<linalg::YieldOp>(loc, args[0]);
1887  });
1888 }
1889 
1890 void TransposeOp::build(::mlir::OpBuilder &builder,
1891  ::mlir::OperationState &result, Value input, Value init,
1892  DenseI64ArrayAttr permutation,
1893  ArrayRef<NamedAttribute> attributes) {
1894  result.addOperands(input);
1895  result.addOperands(init);
1896  result.addAttribute(getPermutationAttrName(result.name), permutation);
1897  result.addAttributes(attributes);
1898 
1899  // Add output types for `RankedTensorType` output arguments.
1900  Type initType = init.getType();
1901  if (llvm::isa<RankedTensorType>(initType))
1902  result.addTypes(initType);
1903 
1904  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1905  init);
1906 }
1907 
1908 void TransposeOp::build(::mlir::OpBuilder &builder,
1909  ::mlir::OperationState &result, Value input, Value init,
1910  ArrayRef<int64_t> permutation,
1911  ArrayRef<NamedAttribute> attributes) {
1912  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1913  attributes);
1914 }
1915 
1916 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1917  if (failed(parseDstStyleOp(
1918  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1919  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1920  })))
1921  return failure();
1922 
1923  OpBuilder builder(parser.getContext());
1924  buildIdentityRegion(builder, result.location, *result.addRegion(),
1925  /*inputs=*/result.operands,
1926  /*outputs=*/{});
1927  return success();
1928 }
1929 
1930 void TransposeOp::getAsmResultNames(
1931  function_ref<void(Value, StringRef)> setNameFn) {
1932  if (!getResults().empty())
1933  setNameFn(getResults().front(), "transposed");
1934 }
1935 
1936 void TransposeOp::print(OpAsmPrinter &p) {
1937  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1938  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1939  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1940 }
1941 
1942 LogicalResult TransposeOp::verify() {
1943  ArrayRef<int64_t> permutationRef = getPermutation();
1944 
1945  if (!isPermutationVector(permutationRef))
1946  return emitOpError("permutation is not valid");
1947 
1948  auto inputType = getInput().getType();
1949  auto initType = getInit().getType();
1950 
1951  int64_t rank = inputType.getRank();
1952 
1953  if (rank != initType.getRank())
1954  return emitOpError() << "input rank " << rank
1955  << " does not match init rank " << initType.getRank();
1956 
1957  if (rank != static_cast<int64_t>(permutationRef.size()))
1958  return emitOpError() << "size of permutation " << permutationRef.size()
1959  << " does not match the argument rank " << rank;
1960 
1961  auto inputDims = inputType.getShape();
1962  auto initDims = initType.getShape();
1963 
1964  for (int64_t i = 0; i < rank; ++i) {
1965  int64_t inputDim = inputDims[permutationRef[i]];
1966  int64_t initDim = initDims[i];
1967 
1968  if (inputDim != initDim) {
1969  return emitOpError() << "dim(result, " << i << ") = " << initDim
1970  << " doesn't match dim(input, permutation[" << i
1971  << "]) = " << inputDim;
1972  }
1973  }
1974 
1975  return success();
1976 }
1977 
1978 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1979  int64_t rank = getInit().getType().getRank();
1980  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1981 }
1982 
1983 ArrayAttr TransposeOp::getIndexingMaps() {
1984  Builder builder(getContext());
1985  int64_t rank = getInit().getType().getRank();
1986  return builder.getAffineMapArrayAttr(
1987  {inversePermutation(AffineMap::getPermutationMap(
1988  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1989  builder.getMultiDimIdentityMap(rank)});
1990 }
1991 
1992 void TransposeOp::getEffects(
1993  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1994  &effects) {
1995  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1996 }
1997 
1998 Speculation::Speculatability TransposeOp::getSpeculatability() {
1999  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2000 }
2001 
2002 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2003  SmallVectorImpl<OpFoldResult> &result) {
2004  // Only the tensor type is supported.
2005  if (!isa<TensorType>(getInput().getType()))
2006  return failure();
2007 
2008  // Single dimension transpose.
2009  if (getPermutation().size() == 0) {
2010  result.push_back(getInput());
2011  return success();
2012  }
2013  // Identity permutation.
2014  if (isIdentityPermutation(getPermutation())) {
2015  result.push_back(getInput());
2016  return success();
2017  }
2018 
2019  return failure();
2020 }
2021 
2022 /// Fold transpose with transpose.
2023 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2024  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2025 
2026  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2027  PatternRewriter &rewriter) const override {
2028  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2029  if (!defTransposeOp)
2030  return failure();
2031  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2032  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2033  SmallVector<int64_t> foldedPerms;
2034  foldedPerms.reserve(perms.size());
2035  for (int64_t perm : perms)
2036  foldedPerms.push_back(defPerms[perm]);
2037 
2038  rewriter.replaceOpWithNewOp<TransposeOp>(
2039  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2040  foldedPerms);
2041  return success();
2042  }
2043 };
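// Worked example (illustrative only): if the producer uses
// defPerms = [1, 2, 0] and the consumer uses perms = [2, 0, 1], then
// foldedPerms[i] = defPerms[perms[i]] gives [defPerms[2], defPerms[0],
// defPerms[1]] = [0, 1, 2]; the two transposes cancel into an identity
// transpose, which TransposeOp::fold above then eliminates.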
2044 
2045 /// This pattern canonicalizes transpose by swapping the order of
2046 /// broadcast and transpose:
2047 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2048 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2049  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2050 
2051  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2052  PatternRewriter &rewriter) const override {
2053  Value input = transposeOp.getInput();
2054  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2055  if (!input.hasOneUse() || !broadcastOp)
2056  return failure();
2057 
2058  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2059  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2060 
2061  // Get new perms and new dimensions.
2062  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2063  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2064  SmallVector<int64_t> resultDimensions;
2065  unsigned dimensionSize = dimensions.size();
2066  for (unsigned i = 0; i < dimensionSize; ++i)
2067  resultDimensions.push_back(invertPerm[dimensions[i]]);
2068 
2069  // Create transpose result.
2070  Value broadcastInput = broadcastOp.getInput();
2071  Location loc = transposeOp.getLoc();
2072  MLIRContext *ctx = transposeOp.getContext();
2073  SmallVector<OpFoldResult> dims;
2074  auto broadcastInputTy =
2075  mlir::cast<RankedTensorType>(broadcastInput.getType());
2076  unsigned inputRank = broadcastInputTy.getRank();
2077  for (unsigned i = 0; i < inputRank; ++i) {
2078  if (broadcastInputTy.isDynamicDim(i)) {
2079  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2080  ->getResult(0));
2081  } else {
2082  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2083  broadcastInputTy.getDimSize(i)));
2084  }
2085  }
2086  SmallVector<OpFoldResult> transposeResultShapes =
2087  applyPermutation(dims, resultPerms);
2088  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2089  transposeOp.getLoc(), transposeResultShapes,
2090  broadcastInputTy.getElementType());
2091 
2092  // Create broadcast(transpose(input)).
2093  Value transposeResult =
2094  rewriter
2095  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2096  resultPerms)
2097  ->getResult(0);
2098  rewriter.replaceOpWithNewOp<BroadcastOp>(
2099  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2100  return success();
2101  }
2102 };
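// Illustrative example (hypothetical shapes): with input tensor<2x4xf32>,
// broadcast dimensions = [1] producing tensor<2x8x4xf32>, and transpose
// permutation = [1, 2, 0] producing tensor<8x4x2xf32>, the pattern computes
// resultPerms = [1, 0] and resultDimensions = [0]: it first transposes the
// small input to tensor<4x2xf32> and then broadcasts along dim 0, yielding
// the same tensor<8x4x2xf32> while transposing far fewer elements.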
2103 
2104 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2105  MLIRContext *context) {
2106  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2107 }
2108 
2109 //===----------------------------------------------------------------------===//
2110 // BroadcastOp
2111 //===----------------------------------------------------------------------===//
2112 
2113 void BroadcastOp::build(::mlir::OpBuilder &builder,
2114  ::mlir::OperationState &result, Value input, Value init,
2115  DenseI64ArrayAttr dimensions,
2116  ArrayRef<NamedAttribute> attributes) {
2117  result.addOperands(input);
2118  result.addOperands(init);
2119  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2120  result.addAttributes(attributes);
2121 
2122  // Add output types for `RankedTensorType` output arguments.
2123  Type initType = init.getType();
2124  if (llvm::isa<RankedTensorType>(initType))
2125  result.addTypes(initType);
2126 
2127  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2128  init);
2129 }
2130 
2131 void BroadcastOp::build(::mlir::OpBuilder &builder,
2132  ::mlir::OperationState &result, Value input, Value init,
2133  ArrayRef<int64_t> dimensions,
2134  ArrayRef<NamedAttribute> attributes) {
2135  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2136  attributes);
2137 }
2138 
2139 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2140  if (failed(parseDstStyleOp(
2141  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2142  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2143  })))
2144  return failure();
2145 
2146  OpBuilder builder(parser.getContext());
2147  buildIdentityRegion(builder, result.location, *result.addRegion(),
2148  /*inputs=*/result.operands,
2149  /*outputs=*/{});
2150  return success();
2151 }
2152 
2153 void BroadcastOp::getAsmResultNames(
2154  function_ref<void(Value, StringRef)> setNameFn) {
2155  if (!getResults().empty())
2156  setNameFn(getResults().front(), "broadcasted");
2157 }
2158 
2159 void BroadcastOp::print(OpAsmPrinter &p) {
2160  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2161  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2162  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2163 }
2164 
2165 LogicalResult BroadcastOp::verify() {
2166  ArrayRef<int64_t> dimensionsRef = getDimensions();
2167 
2168  auto inputType = getInput().getType();
2169  auto initType = getInit().getType();
2170 
2171  int64_t inputRank = inputType.getRank();
2172  int64_t initRank = initType.getRank();
2173 
2174  auto inputShape = inputType.getShape();
2175  auto initShape = initType.getShape();
2176 
2177  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2178  return emitOpError() << "input rank plus added dimensions does not "
2179  "match init rank. input rank: "
2180  << inputRank
2181  << ", dimensions size: " << dimensionsRef.size()
2182  << ", init rank: " << initRank;
2183 
2184  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2185  if (dim < 0 || dim >= initRank)
2186  return emitOpError() << "dimension " << idx
2187  << " is out of range. expected range: [0, "
2188  << initRank - 1 << "], got: " << dim;
2189  }
2190 
2191  // Mapping from input dims to init dims.
2192  SmallVector<int64_t> dimMap;
2193  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2194  if (!llvm::is_contained(dimensionsRef, dim))
2195  dimMap.push_back(dim);
2196  }
2197 
2198  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2199  // This dimension is mapped from the input. Init and input dims should
2200  // match.
2201  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2202  return emitOpError() << "input dim " << inputDimIdx
2203  << " should match init dim " << initDimIdx
2204  << ". input: " << inputShape[inputDimIdx]
2205  << ", init: " << initShape[initDimIdx];
2206  }
2207 
2208  return success();
2209 }
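// For instance (illustrative): broadcasting tensor<4x16xf32> into
// tensor<4x8x16xf32> with dimensions = [1] yields dimMap = [0, 2], so
// input dim 0 must equal init dim 0 and input dim 1 must equal init dim 2,
// while init dim 1 is the newly added broadcast dimension.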
2210 
2211 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2212  int64_t rank = getInit().getType().getRank();
2213  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2214 }
2215 
2216 ArrayAttr BroadcastOp::getIndexingMaps() {
2217  Builder builder(getContext());
2218  int64_t rank = getInit().getType().getRank();
2219  return builder.getAffineMapArrayAttr(
2220  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2221  builder.getMultiDimIdentityMap(rank)});
2222 }
2223 
2224 void BroadcastOp::getEffects(
2225  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2226  &effects) {
2227  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2228 }
2229 
2230 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2231  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2232 }
2233 
2234 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2235  MLIRContext *context) {
2236  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2237 }
2238 
2239 //===----------------------------------------------------------------------===//
2240 // YieldOp
2241 //===----------------------------------------------------------------------===//
2242 
2243 void linalg::YieldOp::print(OpAsmPrinter &p) {
2244  if (getNumOperands() > 0)
2245  p << ' ' << getOperands();
2246  p.printOptionalAttrDict((*this)->getAttrs());
2247  if (getNumOperands() > 0)
2248  p << " : " << getOperandTypes();
2249 }
2250 
2251 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2252  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2253  SmallVector<Type, 2> types;
2254  SMLoc loc = parser.getCurrentLocation();
2255  return failure(parser.parseOperandList(opInfo) ||
2256  parser.parseOptionalAttrDict(result.attributes) ||
2257  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2258  parser.resolveOperands(opInfo, types, loc, result.operands));
2259 }
2260 
2261 // Check that the operand count and types match the element types of the
2262 // LinalgOp interface's shaped operands.
2263 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2264  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2265  return op.emitOpError("expected number of yield values (")
2266  << op.getNumOperands()
2267  << ") to match the number of inits / outs operands of the enclosing "
2268  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2269 
2270  for (OpOperand &opOperand : op->getOpOperands()) {
2271  OpOperand *outputOperand =
2272  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2273  Type elementType = outputOperand->get().getType();
2274  if (isa<MemRefType, RankedTensorType>(elementType))
2275  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2276  if (opOperand.get().getType() != elementType)
2277  return op.emitOpError("type of yield operand ")
2278  << (opOperand.getOperandNumber() + 1) << " ("
2279  << opOperand.get().getType() << ") doesn't match "
2280  << "the element type of the enclosing linalg.generic op ("
2281  << elementType << ")";
2282  }
2283  return success();
2284 }
2285 
2286 LogicalResult linalg::YieldOp::verify() {
2287  auto *parentOp = (*this)->getParentOp();
2288  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2289  return emitOpError("expected single non-empty parent region");
2290 
2291  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2292  return verifyYield(*this, linalgOp);
2293 
2294  return emitOpError("expected parent op with LinalgOp interface");
2295 }
2296 
2297 //===----------------------------------------------------------------------===//
2298 // IndexOp
2299 //===----------------------------------------------------------------------===//
2300 
2301 LogicalResult IndexOp::verify() {
2302  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2303  if (!linalgOp)
2304  return emitOpError("expected parent op with LinalgOp interface");
2305  if (linalgOp.getNumLoops() <= getDim())
2306  return emitOpError("expected dim (")
2307  << getDim() << ") to be lower than the number of loops ("
2308  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2309  return success();
2310 }
2311 
2312 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2313  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2314  // Bail out if `linalg.index` does not have a proper parent yet at this
2315  // point, e.g., when calling `createOrFold` during IR construction in
2316  // `genericOp::build`.
2317  if (!linalgOp)
2318  return OpFoldResult{};
2319 
2320  // Index of unit dims is always 0.
2321  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2322  uint64_t dim = getDim();
2323  assert(dim < loopBounds.size() && "Dim is out of bounds");
2324  if (loopBounds[dim] == 1)
2325  return IntegerAttr::get(IndexType::get(getContext()), 0);
2326 
2327  return OpFoldResult{};
2328 }
2329 
2330 /////// Operations corresponding to library calls defined with Tablegen ////////
2331 
2332 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2333 
2334 #define GET_OP_CLASSES
2335 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2336 
2337 #define GET_OP_CLASSES
2338 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2339 #define GET_OP_CLASSES
2340 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2341 
2342 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2343  unsigned rank,
2344  MLIRContext *context) {
2345  if (maybeMap)
2346  return *maybeMap;
2347  if (rank == 0)
2348  return AffineMap::get(context);
2349  return AffineMap::getMultiDimIdentityMap(rank, context);
2350 }
2351 
2352 SmallVector<AffineExpr, 4>
2353 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2354  MLIRContext *context) {
2355  SmallVector<AffineExpr, 4> res;
2356  res.reserve(num);
2357  for (unsigned i = 0; i < num; ++i)
2358  res.push_back(getAffineDimExpr(startIdx++, context));
2359  return res;
2360 }
2361 
2362 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2363  ArrayRef<AffineExpr> b) {
2364  auto rangeA = llvm::make_range(a.begin(), a.end());
2365  auto rangeB = llvm::make_range(b.begin(), b.end());
2366  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2367  return llvm::to_vector<4>(concatRanges);
2368 }
2369 
2370 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2371  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2372  ss << "view";
2373  for (auto size : memref.getShape())
2374  if (size < 0)
2375  ss << "sx";
2376  else
2377  ss << size << "x";
2378  if (failed(appendMangledType(ss, memref.getElementType())))
2379  return failure();
2380  if (auto as = memref.getMemorySpace()) {
2381  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2382  ss << "as" << attr.getInt();
2383  else
2384  return failure();
2385  }
2386  return success();
2387  }
2388  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2389  ss << "vector";
2390  llvm::interleave(
2391  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2392  if (failed(appendMangledType(ss, vec.getElementType())))
2393  return failure();
2394  return success();
2395  }
2396  if (t.isSignlessIntOrIndexOrFloat()) {
2397  ss << t;
2398  return success();
2399  }
2400  return failure();
2401 }
2402 
2403 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2404  assert(isa<LinalgOp>(op));
2405  std::string name(op->getName().getStringRef().str());
2406  std::string fun = "";
2407  for (NamedAttribute kv : op->getAttrs()) {
2408  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2409  fun = stringifyEnum(ufa.getValue()).str() + "_";
2410  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2411  fun = stringifyEnum(bfa.getValue()).str() + "_";
2412  }
2413  }
2414  name.reserve(128);
2415  llvm::replace(name, '.', '_');
2416  llvm::raw_string_ostream ss(name);
2417  ss << "_" << fun;
2418  for (Type t : op->getOperandTypes()) {
2419  if (failed(appendMangledType(ss, t)))
2420  return std::string();
2421  ss << "_";
2422  }
2423  name.pop_back();
2424  return name;
2425 }
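// Worked example (illustrative, assuming the mangling scheme above): for a
// linalg.copy whose two operands are memref<?x8xf32>, the op name becomes
// "linalg_copy", a dynamic size mangles as "s", and each operand contributes
// "viewsx8xf32", so the generated name is
// "linalg_copy_viewsx8xf32_viewsx8xf32".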
2426 
2427 //===----------------------------------------------------------------------===//
2428 // Canonicalizers and Folders.
2429 //===----------------------------------------------------------------------===//
2430 
2431 namespace {
2432 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2433  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2434 
2435  LogicalResult matchAndRewrite(LinalgOp op,
2436  PatternRewriter &rewriter) const override {
2437  for (OpOperand &opOperand : op->getOpOperands()) {
2438  // Linalg "inputs" may be either tensor or memref type.
2439  // tensor<0xelt_type> is a convention that may not always mean
2440  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2441  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2442  if (!mt)
2443  continue;
2444  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2445  rewriter.eraseOp(op);
2446  return success();
2447  }
2448  }
2449  return failure();
2450  }
2451 };
2452 
2453 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2454 /// result that is more static than the linalg op.
2455 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2456  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2457 
2458  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2459  PatternRewriter &rewriter) const override {
2460  if (!tensor::canFoldIntoProducerOp(castOp))
2461  return failure();
2462 
2463  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2464  if (!linalgOp)
2465  return failure();
2466 
2467  // The cast can be in a conditionally reachable region, in which case
2468  // folding will generate invalid code. Only conservatively fold ops in the
2469  // same block for now.
2470  if (castOp->getBlock() != linalgOp->getBlock())
2471  return failure();
2472 
2473  OpBuilder::InsertionGuard guard(rewriter);
2474  rewriter.setInsertionPoint(linalgOp);
2475 
2476  Location loc = linalgOp.getLoc();
2477  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2478  unsigned resultNumber = resultValue.getResultNumber();
2479  auto resultType =
2480  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2481  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2482  // going from a more dynamic shape to a less dynamic shape. If the producer
2483  // for this cast, i.e. producer of the out operand, is also an operation
2484  // that folds with tensor.cast consumer (like this pattern), the cast will
2485  // continue to propagate as far up the stack as it can go.
2486  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2487  Value newOperand =
2488  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2489  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2490  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2491  linalgOp.getDpsInits().end());
2492  outputOperands[resultNumber] = newOperand;
2493  newOperands.append(outputOperands.begin(), outputOperands.end());
2494 
2495  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2496  linalgOp->result_type_end());
2497  resultTypes[resultNumber] = resultType;
2498  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2499 
2500  // Create a tensor.cast operation back to the original type.
2501  Value castBack = rewriter.create<tensor::CastOp>(
2502  loc, resultValue.getType(), newOp->getResult(resultNumber));
2503 
2504  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2505  results[resultNumber] = castBack;
2506  rewriter.replaceOp(linalgOp, results);
2507  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2508  return success();
2509  }
2510 };
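// Sketch of the rewrite on hypothetical IR:
//   %0 = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
// becomes
//   %c = tensor.cast %out : tensor<?x?xf32> to tensor<4x8xf32>
//   %0 = linalg.generic ... outs(%c : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %1 = tensor.cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
// The static type thus propagates into the producer, and the residual cast
// can keep folding further up the use-def chain.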
2511 
2512 /// For each operand in `operands`, this function maps the affine dim
2513 /// expressions of statically sized dimensions to their known sizes.
2514 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2515  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2516  for (OpOperand &opOperand : operands) {
2517  if (linalgOp.isScalar(&opOperand))
2518  continue;
2519  Value src = opOperand.get();
2520  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2521  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2522 
2523  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2524  // `tensor.cast` operation and source of the cast operation has a static
2525  // shape, then assign it to the `sourceShape`.
2526  auto *parentOp = src.getDefiningOp();
2527  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2528  if (parentOp) {
2529  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2530  Value castSource = castOp.getSource();
2531  auto castSourceType =
2532  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2533  if (castSourceType && castSourceType.hasStaticShape())
2534  sourceShape = castSourceType.getShape();
2535  }
2536  }
2537 
2538  // If the source shape's dimension has a static shape, map the affine dim
2539  // expression to the known static size.
2540  for (unsigned i = 0; i < sourceShape.size(); i++) {
2541  if (sourceType.isDynamicDim(i))
2542  continue;
2543  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2544  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2545  }
2546  }
2547 }
2548 
2549 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2550 /// sizes mapped in `affineExprToSize`. New operands are appended to
2551 /// `newOperands` and their result types are stored in `resultTypes`. If
2552 /// `opOperand` requires no change, then `changeNeeded` stays false and the
2553 /// same operand is added to the `newOperands` list.
2554 static void createNewOperandWithStaticSizes(
2555  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2556  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2557  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2558  bool &changeNeeded) {
2559  Value src = opOperand->get();
2560  newOperands.push_back(src);
2561  if (linalgOp.isScalar(opOperand))
2562  return;
2563  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2564  Type resultType = sourceType;
2565  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2566  resultTypes.push_back(resultType);
2567  return;
2568  }
2569  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2570  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2571  SmallVector<int64_t> newShape;
2572  // If the operand is updated with a new shape, `newOperandNeeded` will be
2573  // set to true.
2574  bool newOperandNeeded = false;
2575  for (unsigned i = 0; i < sourceShape.size(); i++) {
2576  int64_t dimShape = sourceShape[i];
2577  AffineExpr dimExpr = sourceMap.getResult(i);
2578  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2579  newShape.push_back(dimShape);
2580  continue;
2581  }
2582  // Dimension has a dynamic shape and corresponding affine dim
2583  // expression is present in the map. So assign the size for the
2584  // given affine dim expression to the dimension.
2585  newShape.push_back(affineExprToSize[dimExpr]);
2586  newOperandNeeded = true;
2587  }
2588  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2589  sourceType.getEncoding());
2590  if (newOperandNeeded) {
2591  changeNeeded = true;
2592  // Get the new operand value given its size and element type by
2593  // casting it.
2594  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2595  unsigned index = opOperand->getOperandNumber();
2596  newOperands[index] = newOperand;
2597  }
2598  if (linalgOp.isDpsInit(opOperand))
2599  resultTypes.push_back(resultType);
2600 }
2601 
2602 /// Static shapes for the operands can be inferred if any one of the operands
2603 /// has a static shape. This can be done by referring to the affine dim
2604 /// expressions of the operand.
2605 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2606  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2607 
2608  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2609  PatternRewriter &rewriter) const override {
2610  if (!linalgOp.hasPureTensorSemantics())
2611  return failure();
2612 
2613  // Maps must be projected permutations.
2614  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2615  return !map.isProjectedPermutation();
2616  }))
2617  return failure();
2618 
2619  // Maps affine dim expressions to the static size of that dimension.
2620  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2621  Location loc = linalgOp.getLoc();
2622 
2623  // For each affine dim expression, check if the size is known. If it is,
2624  // add it to the map.
2625  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2626 
2627  SmallVector<Value> newOperands;
2628  SmallVector<Type> resultTypes;
2629 
2630  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2631  // change in their types.
2632  bool changeNeeded = false;
2633  newOperands.reserve(linalgOp->getNumOperands());
2634  resultTypes.reserve(linalgOp.getNumDpsInits());
2635 
2636  // Iterate over all the operands and update the static sizes.
2637  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2638  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2639  affineExprToSize, linalgOp, newOperands,
2640  resultTypes, changeNeeded);
2641  }
2642 
2643  // If the generic op has all the required static information, no
2644  // canonicalization needed.
2645  if (!changeNeeded)
2646  return failure();
2647 
2648  // Clone op.
2649  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2650  SmallVector<Value> replacements;
2651  replacements.reserve(newOp->getNumResults());
2652  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2653  Value newResult = std::get<1>(it);
2654  Value oldResult = std::get<0>(it);
2655  Type newType = newResult.getType();
2656  Type oldType = oldResult.getType();
2657  replacements.push_back(
2658  (newType != oldType)
2659  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2660  : newResult);
2661  }
2662  rewriter.replaceOp(linalgOp, replacements);
2663  return success();
2664  }
2665 };
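// For example (a sketch): in a linalg.generic with identity indexing maps
// over %a : tensor<4x?xf32> and %b : tensor<?x8xf32>, d0 is known to be 4
// from %a and d1 to be 8 from %b, so both operands are cast to
// tensor<4x8xf32> before the op is cloned, and the new results are cast
// back to the original (more dynamic) types.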
2666 
2667 } // namespace
2668 
2669 // All named ops canonicalizers and folders are auto-generated in the
2670 // .cpp.inc.
2671 
2672 //===----------------------------------------------------------------------===//
2673 // SoftmaxOp
2674 //===----------------------------------------------------------------------===//
2675 
2676 LogicalResult SoftmaxOp::verify() {
2677  ShapedType inputType = getInputOperandType();
2678  ShapedType outputType = getOutputOperandType();
2679 
2680  ArrayRef<int64_t> inputShape = inputType.getShape();
2681  ArrayRef<int64_t> outputShape = outputType.getShape();
2682  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2683  return emitOpError("incompatible output shape");
2684 
2685  int64_t inputRank = getInputOperandRank();
2686  int64_t dimension = getDimension();
2687  if ((dimension < 0) || (dimension >= inputRank))
2688  return emitOpError("incorrect dimension specified");
2689 
2690  return success();
2691 }
2692 
2693 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2694  int64_t operandRank = getInputOperandRank();
2695  SmallVector<Range> loopBounds(operandRank);
2696  Location loc = getLoc();
2697  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2698  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2699  Value source = getInput();
2700  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2701  loopBounds[dim].offset = zero;
2702  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2703  loopBounds[dim].stride = one;
2704  }
2705  return loopBounds;
2706 }
2707 
2708 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2709  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2710  utils::IteratorType::parallel);
2711  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2712  return iteratorTypes;
2713 }
2714 
2715 FailureOr<TilingResult>
2716 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2717  ArrayRef<OpFoldResult> offsets,
2718  ArrayRef<OpFoldResult> sizes) {
2719  int64_t rank = getInputOperandRank();
2720  auto oneAttr = builder.getI64IntegerAttr(1);
2721  SmallVector<OpFoldResult> strides(rank, oneAttr);
2722  SmallVector<Value> tiledOperands;
2723  Operation *inputSlice =
2724  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2725  if (!inputSlice) {
2726  return emitOpError("failed to compute input slice");
2727  }
2728  tiledOperands.emplace_back(inputSlice->getResult(0));
2729  Operation *outputSlice =
2730  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2731  if (!outputSlice) {
2732  return emitOpError("failed to compute output slice");
2733  }
2734  tiledOperands.emplace_back(outputSlice->getResult(0));
2735 
2736  SmallVector<Type, 4> resultTypes;
2737  if (hasPureTensorSemantics())
2738  resultTypes.push_back(tiledOperands[1].getType());
2739  Operation *tiledOp =
2740  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2741 
2742  return TilingResult{
2743  {tiledOp},
2744  SmallVector<Value>(tiledOp->getResults()),
2745  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2746 }
2747 
2748 LogicalResult SoftmaxOp::getResultTilePosition(
2749  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2750  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2751  SmallVector<OpFoldResult> &resultSizes) {
2752  if (resultNumber == 0) {
2753  resultOffsets.assign(offsets.begin(), offsets.end());
2754  resultSizes.assign(sizes.begin(), sizes.end());
2755  return success();
2756  }
2757  return failure();
2758 }
2759 
2760 // cast(dynamic) -> static.
2761 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2762  return memref::foldMemRefCast(*this);
2763 }
2764 
2765 LogicalResult
2766 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2767  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2768  SmallVector<OpFoldResult> shapes;
2769  Location loc = getOperation()->getLoc();
2770  IRRewriter rewriter(b);
2771  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2772  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2773  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2774  if (!outputShapedType.isDynamicDim(dim)) {
2775  // Static dim: Return IntegerAttr.
2776  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2777  } else {
2778  // Dynamic dim: Return Value.
2779  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2780  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2781  }
2782  }
2783  reifiedReturnShapes.emplace_back(std::move(shapes));
2784  return success();
2785 }
2786 
2787 void SoftmaxOp::getEffects(
2788  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2789  &effects) {
2790  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2791  if (!llvm::isa<MemRefType>(operand.getType()))
2792  continue;
2793  effects.emplace_back(MemoryEffects::Read::get(),
2794  &getOperation()->getOpOperand(index), /*stage=*/0,
2795  /*effectOnFullRegion=*/true,
2796  SideEffects::DefaultResource::get());
2797  }
2798 
2799  for (OpOperand &operand : getDpsInitsMutable()) {
2800  if (!llvm::isa<MemRefType>(operand.get().getType()))
2801  continue;
2802  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2803  /*effectOnFullRegion=*/true,
2804  SideEffects::DefaultResource::get());
2805  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2806  /*effectOnFullRegion=*/true,
2807  SideEffects::DefaultResource::get());
2808  }
2809 }
2810 
2811 // Helper functions for softmax decomposition.
2812 // @{
2813 
2814 // Helper function to produce the iterator types (reduction or parallel) and
2815 // affine maps for the iterators used in the decomposition of softmax.
2816 // This method creates:
2817 // If allParallel == true:
2818 // - iterator type: {parallel, ..., parallel}
2819 // - affine maps:
2820 // -- identity with inputRank dimensions.
2821 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2822 // where N == inputRank.
2823 //
2824 // If allParallel == false:
2825 // - iterator type at dim(i) == parallel for i != \p dim and
2826 // dim(dim) == reduction.
2827 // - affine map:
2828 // -- identity with inputRank dimensions.
2829 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2830 // where N == inputRank.
2831 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2832 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2833  int64_t dim, bool allParallel = false) {
2834  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2835  utils::IteratorType::parallel);
2836  if (!allParallel)
2837  iteratorTypes[dim] = utils::IteratorType::reduction;
2838  MLIRContext *ctxt = builder.getContext();
2839  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2840  SmallVector<AffineExpr, 2> affineExprs;
2841  for (int i = 0; i < inputRank; i++) {
2842  if (i != dim)
2843  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2844  }
2845  auto reductionMap =
2846  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2847  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2848  return std::make_tuple(iteratorTypes, indexingMaps);
2849 }
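// For instance (an illustrative case, not from the source): with
// inputRank == 3 and dim == 1, this helper returns iterator types
// {parallel, reduction, parallel} (all parallel when allParallel is set)
// and the two maps
//   (d0, d1, d2) -> (d0, d1, d2)   // identity, for the input
//   (d0, d1, d2) -> (d0, d2)       // d1 dropped, for the reduced operand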
2850 
2851 // Helper function to produce a linalg.generic that computes a reduction on
2852 // dimension \p dim with the operation type \p T.
2853 template <typename T>
2854 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2855  int64_t dim) {
2856  auto inputType = cast<ShapedType>(input.getType());
2857  ArrayRef<int64_t> inputShape = inputType.getShape();
2858  int64_t inputRank = inputShape.size();
2859  auto [iteratorTypes, indexingMaps] =
2860  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2861  assert(indexingMaps.size() == 2 &&
2862  "We should have two maps: 1 for the input, 1 for the output");
2863  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2864 
2865  auto genericOp = builder.create<linalg::GenericOp>(
2866  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2867  [&](OpBuilder &b, Location loc, ValueRange args) {
2868  Value result = b.create<T>(loc, args[0], args[1]);
2869  b.create<linalg::YieldOp>(loc, result);
2870  });
2871  return genericOp.getResult(0);
2872 }
2873 
2874 /// Produce a linalg generic that computes the second step of the softmax
2875 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2876 /// on dimension \p dim.
2877 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2878  Value max, Value output, int64_t dim) {
2879  auto inputType = cast<ShapedType>(input.getType());
2880  ArrayRef<int64_t> inputShape = inputType.getShape();
2881  int64_t inputRank = inputShape.size();
2882  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2883  builder, inputRank, dim, /*allParallel=*/true);
2884  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2885  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2886  // Add the affine map for the output argument.
2887  indexingMaps.push_back(indexingMaps[0]);
2888  auto genericOp = builder.create<linalg::GenericOp>(
2889  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2890  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2891  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2892  Value result = b.create<math::ExpOp>(loc, diff);
2893  b.create<linalg::YieldOp>(loc, result);
2894  });
2895  return genericOp.getResult(0);
2896 }
2897 
2898 /// Produce a linalg generic that computes the final step of the softmax
2899 /// decomposition.
2900 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2901 /// yield n / d
2902 /// }
2903 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2904  Value denominator, Value output, int64_t dim) {
2905  auto inputType = cast<ShapedType>(numerator.getType());
2906  ArrayRef<int64_t> inputShape = inputType.getShape();
2907  int64_t inputRank = inputShape.size();
2908  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2909  builder, inputRank, dim, /*allParallel=*/true);
2910  assert(indexingMaps.size() == 2 &&
2911  "We should have one map for each input (2)");
2912  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2913  // Add the affine map for the output tensor.
2914  indexingMaps.push_back(indexingMaps[0]);
2915  auto genericOp = builder.create<linalg::GenericOp>(
2916  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2917  indexingMaps, iteratorTypes,
2918  [&](OpBuilder &b, Location loc, ValueRange args) {
2919  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2920  b.create<linalg::YieldOp>(loc, result);
2921  });
2922  return genericOp.getResult(0);
2923 }
2924 // @} End helper functions for softmax decomposition.
2925 
2926 /// Given an N-dimensional tensor x, this method converts
2927 /// softmax(x) to the following sequence of operations:
2928 ///
2929 /// 1. Compute the max of x along dimension d. This results
2930 /// in an (N-1)-dimensional tensor m.
2931 /// m = max(x, dim = d)
2932 ///
2933 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2934 /// an N-dimensional tensor z.
2935 /// z = exp(x - m)
2936 ///
2937 /// 3. Compute the sum of z along dimension d. This results in
2938 /// an (N-1)-dimensional tensor l.
2939 /// l = sum(z, dim = d)
2940 ///
2941 /// 4. Divide z by l. This gives the N-dimensional softmax.
2942 /// softmax = z / l
2943 ///
2944 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2945  OpBuilder::InsertionGuard guard(b);
2946  b.setInsertionPoint(*this);
2947  Location loc = getLoc();
2948  Value input = getInput();
2949  ShapedType inputType = getInputOperandType();
2950  Type elementType = inputType.getElementType();
2951  int64_t reductionDim = getDimension();
2952  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2953  Value output = getOutput();
2954  dims.erase(dims.begin() + reductionDim);
2955  // Step 1: Compute max along dim.
2956  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2957  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2958  elementType, b, loc,
2959  /*useOnlyFiniteValue=*/true);
2960  Value neutralForMaxFInit =
2961  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2962  .result();
2963  Value max =
2964  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2965 
2966  // Step 2: Subtract max from input and exponentiate.
2967  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2968 
2969  // Step 3: Compute sum along dim.
2970  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2971  b, loc, /*useOnlyFiniteValue=*/true);
2972  Value zeroInit =
2973  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2974  Value denominator =
2975  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2976 
2977  // Step 4: Compute softmax.
2978  Value result =
2979  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2980  return SmallVector<Value>{result};
2981 }
2982 
2983 //===----------------------------------------------------------------------===//
2984 // WinogradFilterTransformOp
2985 //===----------------------------------------------------------------------===//
2986 
2987 LogicalResult WinogradFilterTransformOp::verify() {
2988  auto filterType = cast<ShapedType>(getFilter().getType());
2989  ArrayRef<int64_t> filterShape = filterType.getShape();
2990  int64_t filterH = filterShape[getFilterHDim()];
2991  int64_t filterW = filterShape[getFilterWDim()];
2992  int64_t r = getR();
2993  int64_t m = getM();
2994 
2995  if (filterH != r && filterH != 1)
2996  return emitOpError("expect filter height either equals to r or 1");
2997  if (filterW != r && filterW != 1)
2998  return emitOpError("expect filter width either equals to r or 1");
2999  if (filterH == 1 && filterW == 1)
3000  return emitOpError("expect either filter height or width equals to r");
3001 
3002  SmallVector<int64_t> expectedOutputShape;
3003  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3004  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3005  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3006  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3007 
3008  auto outputType = cast<ShapedType>(getOutput().getType());
3009  ArrayRef<int64_t> outputShape = outputType.getShape();
3010  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3011  return emitOpError("the output shape is not expected");
3012  }
3013  return success();
3014 }
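// Example (a hedged sketch of the shape contract checked above): with m = 4
// and r = 3, alpha = m + r - 1 = 6, so a filter tensor<128x3x3x64xf32>
// (F, KH, KW, C) transforms to tensor<6x6x64x128xf32> (alphaH, alphaW, C, F).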
3015 
3016 SmallVector<Range>
3017 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3018  Location loc = getLoc();
3019  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3020  IntegerAttr oneAttr = builder.getIndexAttr(1);
3021  Value filter = getFilter();
3022  int64_t filterRank = getFilterOperandRank();
3023  SmallVector<Range> loopBounds(filterRank);
3024  for (unsigned dim = 0; dim < filterRank; ++dim) {
3025  loopBounds[dim].offset = zeroAttr;
3026  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3027  loopBounds[dim].stride = oneAttr;
3028  }
3029  return loopBounds;
3030 }
3031 
3032 SmallVector<utils::IteratorType>
3033 WinogradFilterTransformOp::getLoopIteratorTypes() {
3034  int64_t filterRank = getFilterOperandRank();
3035  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3036  utils::IteratorType::parallel);
3037  return iteratorTypes;
3038 }
3039 
3040 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3041  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3042  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3043  SmallVector<OpFoldResult> &resultSizes) {
3044  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3045  ShapedType filterType = getFilterOperandType();
3046  ArrayRef<int64_t> filterShape = filterType.getShape();
3047  int64_t filterH = filterShape[getFilterHDim()];
3048  int64_t filterW = filterShape[getFilterWDim()];
3049  int64_t m = getM();
3050  int64_t r = getR();
3051  int64_t alpha = m + r - 1;
3052  int64_t alphaH = filterH != 1 ? alpha : 1;
3053  int64_t alphaW = filterW != 1 ? alpha : 1;
3054  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3055  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3056 
3057  resultOffsets.append(
3058  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3059  resultSizes.append(
3060  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3061 
3062  return success();
3063 }
3064 
3065 /// Implements tiling for winograd_filter_transform.
3066 /// The input of winograd_filter_transform is (F, KH, KW, C).
3067 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3068 /// Users can specify the tile sizes of F and C.
3069 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3070 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3071 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3072  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3073  ArrayRef<OpFoldResult> sizes) {
3074  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3075  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3076  ShapedType filterType = getFilterOperandType();
3077  ArrayRef<int64_t> filterShape = filterType.getShape();
3078  int64_t filterH = filterShape[getFilterHDim()];
3079  int64_t filterW = filterShape[getFilterWDim()];
3080  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3081  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3082  SmallVector<Value> tiledOperands;
3083  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3084 
3085  sliceOffsets.append(
3086  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3087  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3088  sizes[getFilterCDim()]});
3089  int64_t filterRank = getFilterOperandRank();
3090  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3091  Location loc = getLoc();
3092  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3093  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3094  tiledOperands.emplace_back(filterSlice);
3095 
3096  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3097  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3098  resultSizes)))
3099  return failure();
3100 
3101  int64_t outputRank = getOutputOperandRank();
3102  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3103  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3104  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3105  tiledOperands.emplace_back(outputSlice);
3106 
3107  SmallVector<Type> resultTypes;
3108  resultTypes.push_back(tiledOperands[1].getType());
3109  Operation *tiledOp =
3110  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3111 
3112  return TilingResult{
3113  {tiledOp},
3114  SmallVector<Value>(tiledOp->getResults()),
3115  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3116 }
3117 
3118 //===----------------------------------------------------------------------===//
3119 // WinogradInputTransformOp
3120 //===----------------------------------------------------------------------===//
3121 
3122 LogicalResult WinogradInputTransformOp::verify() {
3123  auto inputType = cast<ShapedType>(getInput().getType());
3124  ArrayRef<int64_t> inputShape = inputType.getShape();
3125  int64_t inputH = inputShape[getInputHDim()];
3126  int64_t inputW = inputShape[getInputWDim()];
3127  int m = getM();
3128  int r = getR();
3129  int64_t tileSize = m + r - 1;
3130 
3131  auto outputType = cast<ShapedType>(getOutput().getType());
3132  ArrayRef<int64_t> outputShape = outputType.getShape();
3133  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3134  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3135 
3136  SmallVector<int64_t> expectedOutputShape(6, inputH);
3137  if (ShapedType::isDynamic(inputH)) {
3138  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3139  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3140  } else {
3141  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3142  expectedOutputShape[getOutputTileHDim()] =
3143  leftTransform ? (inputH - (r - 1)) / m : inputH;
3144  }
3145  if (ShapedType::isDynamic(inputW)) {
3146  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3147  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3148  } else {
3149  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3150  expectedOutputShape[getOutputTileWDim()] =
3151  rightTransform ? (inputW - (r - 1)) / m : inputW;
3152  }
3153  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3154  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3155 
3156  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3157  return emitOpError("unexpected output shape");
3158  }
3159  return success();
3160 }
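// Example (a hedged sketch of the shape contract checked above): with m = 4,
// r = 3 (tile size alpha = 6) and input tensor<1x14x14x32xf32> (N, H, W, C),
// each spatial dim yields (14 - (r - 1)) / m = 3 tiles, so the output is
// tensor<6x6x3x3x1x32xf32> (alphaH, alphaW, tileH, tileW, N, C).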
3161 
3162 SmallVector<Range>
3163 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3164  Location loc = getLoc();
3165  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3166  IntegerAttr oneAttr = builder.getIndexAttr(1);
3167  Value output = getOutput();
3168  int64_t outputRank = getOutputOperandRank();
3169  SmallVector<Range> loopBounds(outputRank);
3170  for (unsigned dim = 0; dim < outputRank; ++dim) {
3171  loopBounds[dim].offset = zeroAttr;
3172  // alphaH, alphaW, tileH, tileW, N, C
3173  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3174  loopBounds[dim].stride = oneAttr;
3175  }
3176  return loopBounds;
3177 }
3178 
3179 SmallVector<utils::IteratorType>
3180 WinogradInputTransformOp::getLoopIteratorTypes() {
3181  int64_t outputRank = getOutputOperandRank();
3182  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3183  utils::IteratorType::parallel);
3184  return iteratorTypes;
3185 }
3186 
3187 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3188  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3189  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3190  SmallVector<OpFoldResult> &resultSizes) {
3191  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3192  ShapedType outputType = getOutputOperandType();
3193  ArrayRef<int64_t> outputShape = outputType.getShape();
3194  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3195  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3196 
3197  int64_t m = getM();
3198  int64_t r = getR();
3199  int64_t alpha = m + r - 1;
3200  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3201  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3202 
3203  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3204  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3205 
3206  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3207  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3208  offsets[getOutputCDim()]});
3209  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3210  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3211  sizes[getOutputCDim()]});
3212 
3213  return success();
3214 }
3215 
3216 /// Implements tiling for winograd_input_transform.
3217 /// The input of winograd_input_transform is (N, H, W, C).
3218 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3219 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3220 /// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3221 /// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3222 FailureOr<TilingResult>
3223 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3224  ArrayRef<OpFoldResult> offsets,
3225  ArrayRef<OpFoldResult> sizes) {
3226  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3227  int64_t m = getM();
3228  int64_t r = getR();
3229 
3230  ShapedType outputType = getOutputOperandType();
3231  ArrayRef<int64_t> outputShape = outputType.getShape();
3232  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3233  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3234 
3235  Location loc = getLoc();
3236  MLIRContext *context = builder.getContext();
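 // Map tile-space offsets and sizes back to input coordinates: a tile starts
 // at tile_offset * m in the input and, because adjacent Winograd tiles
 // overlap by r - 1 rows/columns, a run of tile_size tiles reads
 // tile_size * m + (r - 1) input elements (illustrative numbers only: m = 4,
 // r = 3, tile_size = 2 reads 2 * 4 + 2 = 10 input rows).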
3237  auto identityAffineMap =
3238  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3239  auto offsetAffineMap =
3240  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3241  Value mappedOffsetH = affine::makeComposedAffineApply(
3242  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3243  offsets[getOutputTileHDim()]);
3244  Value mappedOffsetW = affine::makeComposedAffineApply(
3245  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3246  offsets[getOutputTileWDim()]);
3247  auto sizeAffineMap = AffineMap::get(
3248  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3249  Value mappedSizeH = affine::makeComposedAffineApply(
3250  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3251  Value mappedSizeW = affine::makeComposedAffineApply(
3252  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3253 
3254  SmallVector<Value> tiledOperands;
3255  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3256 
3257  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3258  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3259  sliceOffsets.append(
3260  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3261  OpFoldResult sizeH =
3262  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3263  OpFoldResult sizeW =
3264  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3265  sliceSizes.append(
3266  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3267  int64_t inputRank = getInputOperandRank();
3268  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3269  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3270  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3271  tiledOperands.emplace_back(inputSlice);
3272 
3273  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3274  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3275  resultSizes)))
3276  return failure();
3277 
3278  int64_t outputRank = getOutputOperandRank();
3279  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3280  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3281  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3282  tiledOperands.emplace_back(outputSlice);
3283 
3284  SmallVector<Type> resultTypes;
3285  resultTypes.push_back(tiledOperands[1].getType());
3286  Operation *tiledOp =
3287  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3288 
3289  return TilingResult{
3290  {tiledOp},
3291  SmallVector<Value>(tiledOp->getResults()),
3292  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3293 }
3294 
3295 //===----------------------------------------------------------------------===//
3296 // WinogradOutputTransformOp
3297 //===----------------------------------------------------------------------===//
3298 
3299 LogicalResult WinogradOutputTransformOp::verify() {
3300  auto valueType = cast<ShapedType>(getValue().getType());
3301  ArrayRef<int64_t> valueShape = valueType.getShape();
3302  int64_t valueH = valueShape[getValueAlphaHDim()];
3303  int64_t valueW = valueShape[getValueAlphaWDim()];
3304  int64_t valueTileH = valueShape[getValueTileHDim()];
3305  int64_t valueTileW = valueShape[getValueTileWDim()];
3306  int m = getM();
3307  int r = getR();
3308  bool leftTransform = valueH != 1;
3309  bool rightTransform = valueW != 1;
3310 
3311  int64_t outputRank = getOutputOperandRank();
3312  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3313  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3314  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3315  } else {
3316  if (valueH != (leftTransform ? m + r - 1 : 1))
3317  return emitOpError("expect input height equals to input tile size");
3318  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3319  }
3320  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3321  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3322  } else {
3323  if (valueW != (rightTransform ? m + r - 1 : 1))
3324  return emitOpError("expect input width equals to input tile size");
3325  expectedOutputShape[getOutputWDim()] =
3326  (rightTransform ? m : 1) * valueTileW;
3327  }
3328  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3329  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3330 
3331  auto outputType = cast<ShapedType>(getOutput().getType());
3332  ArrayRef<int64_t> outputShape = outputType.getShape();
3333  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3334  return emitOpError("the output shape is not expected");
3335  }
3336  return success();
3337 }
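// Example (a hedged sketch of the shape contract checked above): with m = 4
// and r = 3, a value tensor<6x6x3x3x1x32xf32> (alphaH, alphaW, tileH, tileW,
// N, F) folds back to an output tensor<1x12x12x32xf32>, since
// H = W = m * tileH = 4 * 3 = 12.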
3338 
3339 SmallVector<Range>
3340 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3341  Location loc = getLoc();
3342  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3343  IntegerAttr oneAttr = builder.getIndexAttr(1);
3344  Value value = getValue();
3345  int64_t valueRank = getValueOperandRank();
3346  SmallVector<Range> loopBounds(valueRank);
3347  for (unsigned dim = 0; dim < valueRank; ++dim) {
3348  loopBounds[dim].offset = zeroAttr;
3349  // alphaH, alphaW, tileH, tileW, N, F
3350  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3351  loopBounds[dim].stride = oneAttr;
3352  }
3353  return loopBounds;
3354 }
3355 
3356 SmallVector<utils::IteratorType>
3357 WinogradOutputTransformOp::getLoopIteratorTypes() {
3358  int64_t valueRank = getValueOperandRank();
3359  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3360  utils::IteratorType::parallel);
3361  return iteratorTypes;
3362 }
3363 
3364 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3365  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3366  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3367  SmallVector<OpFoldResult> &resultSizes) {
3368  int64_t m = getM();
3369 
3370  Location loc = getLoc();
3371  MLIRContext *context = builder.getContext();
3372  auto identityAffineMap =
3373  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3374  auto affineMap =
3375  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3376 
3377  ShapedType valueType = getValueOperandType();
3378  ArrayRef<int64_t> valueShape = valueType.getShape();
3379  int64_t valueH = valueShape[0];
3380  int64_t valueW = valueShape[1];
3381  Value mappedOffsetH = affine::makeComposedAffineApply(
3382  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3383  offsets[getValueTileHDim()]);
3384  Value mappedOffsetW = affine::makeComposedAffineApply(
3385  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3386  offsets[getValueTileWDim()]);
3387  Value mappedSizeH = affine::makeComposedAffineApply(
3388  builder, loc, affineMap, sizes[getValueTileHDim()]);
3389  Value mappedSizeW = affine::makeComposedAffineApply(
3390  builder, loc, affineMap, sizes[getValueTileWDim()]);
3391 
3392  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3393  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3394  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3395  OpFoldResult sizeH =
3396  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3397  OpFoldResult sizeW =
3398  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3399 
3400  resultOffsets.append(
3401  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3402  resultSizes.append(
3403  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3404  return success();
3405 }
3406 
3407 /// Implements tiling for winograd_output_transform.
3408 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3409 /// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3410 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3411 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3412 /// for the sizes of tileH, tileW, N, F for one tile.
3413 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3414  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3415  ArrayRef<OpFoldResult> sizes) {
3416  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3417  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3418  Location loc = getLoc();
3419  SmallVector<Value> tiledOperands;
3420  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3421 
3422  ShapedType valueType = getValueOperandType();
3423  ArrayRef<int64_t> valueShape = valueType.getShape();
3424  int64_t alphaH = valueShape[getValueAlphaHDim()];
3425  int64_t alphaW = valueShape[getValueAlphaWDim()];
3426  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3427  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3428 
3429  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3430  offsets[getValueTileWDim()], offsets[getValueNDim()],
3431  offsets[getValueFDim()]});
3432  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3433  sizes[getValueTileWDim()], sizes[getValueNDim()],
3434  sizes[getValueFDim()]});
3435  int64_t valueRank = getValueOperandRank();
3436  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3437  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3438  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3439  tiledOperands.emplace_back(valueSlice);
3440 
3441  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3442  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3443  resultSizes)))
3444  return failure();
3445 
3446  int64_t outputRank = getOutputOperandRank();
3447  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3448  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3449  loc, getOutput(), resultOffsets, resultSizes, strides);
3450  tiledOperands.emplace_back(outputSlice);
3451 
3452  SmallVector<Type> resultTypes;
3453  resultTypes.push_back(tiledOperands[1].getType());
3454  Operation *tiledOp =
3455  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3456 
3457  return TilingResult{
3458  {tiledOp},
3459  SmallVector<Value>(tiledOp->getResults()),
3460  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3461 }
3462 
3463 //===----------------------------------------------------------------------===//
3464 // LinalgDialect
3465 // TODO: Merge with the LinalgDialect block at the bottom
3466 //===----------------------------------------------------------------------===//
3467 
3468 // Returns true if the result expressions of `subMap` are a subset of `fullMap`.
3469 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3470  auto explicitRange = subMap.getResults();
3471  auto defaultRange = fullMap.getResults();
3472  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3473  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3474  llvm::set_union(explicitSet, defaultSet);
3475  return explicitSet == defaultSet;
3476 }
3477 
3478 /// Check if the user-defined map is a valid broadcast map. Here broadcast
3479 /// indexing maps are defined in the context of the corresponding default
3480 /// indexing maps for the given op. This way the check becomes very simple,
3481 /// i.e. just check the number of result dims.
3482 /// Returns true if `explicitMap` is broadcast with respect to the
3483 /// `defaultMap`.
3484 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3485  return explicitMap.getNumResults() < defaultMap.getNumResults();
3486 }
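// For example, linalg.matmul's default LHS map is (d0, d1, d2) -> (d0, d2)
// (see MatmulOp::getDefaultIndexingMaps below); an explicit map
// (d0, d1, d2) -> (d2) has fewer results and is therefore treated as a
// broadcast of the LHS along d0.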
3487 
3488 /// Verifies the broadcast and transpose semantics specified by the explicit
3489 /// indexing map for the MatmulOp \p matmulOp for each operand specified by
3490 /// \p opIndex.
3491 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3492  unsigned opIndex) {
3493  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3494  SmallVector<AffineMap, 3> defaultIndexingMaps =
3495  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3496 
3497  auto opIndexingMap = opIndexingMaps[opIndex];
3498  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3499  // Check general validity of indexing map results.
3500  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3501  return matmulOp->emitOpError()
3502  << "Unexpected dim expression in map result.";
3503 
3504  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3505  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3506  return matmulOp->emitOpError()
3507  << "Invalid broadcast requested, should be (d2).";
3508  }
3509  return success();
3510  }
3511  return success();
3512 }
3513 
3514 // Check general validity of input indexing map of
3515 // BatchMatmulOp/BatchReduceMatmulOp.
3516 template <typename OpTy>
3517 static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3518  AffineMap opIndexingMap,
3519  AffineMap defaultIndexingMap, bool isLHS) {
3520  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3521  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3522  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3523  // Check the result dims are valid.
3524  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3525  return batchVariantMatmulOp->emitOpError()
3526  << "Unexpected result dim expression (outside the set of default "
3527  "result dims).";
3528 
3529  // Check for valid number of result dims of input maps.
3530  if (opIndexingMap.getNumResults() > 3)
3531  return batchVariantMatmulOp->emitOpError()
3532  << "no. of result dim expressions exceeds 3.";
3533 
3534  auto hasValidBatchDim = [](AffineMap map) {
3535  AffineExpr batchDim = map.getResult(0);
3536  return batchDim.isFunctionOfDim(0);
3537  };
3538 
3539  // Check if the requested broadcast is valid.
3540  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3541  if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3542  return batchVariantMatmulOp->emitOpError()
3543  << "Invalid broadcast requested.";
3544  } else if (!hasValidBatchDim(opIndexingMap)) {
3545  return batchVariantMatmulOp->emitOpError()
3546  << "Invalid batch dimension expression.";
3547  }
3548  return success();
3549 }
3550 
3551 /// This function checks if the given AffineMap for the output of a
3552 /// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3553 /// dimensions and if the output map result dimensions are valid.
3554 template <typename OpTy>
3555 static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3556  AffineMap opIndexingMap) {
3557  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3558  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3559  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3560  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3561  opIndexingMap.getNumResults() != 3) {
3562 
3563  return batchVariantMatmulOp->emitOpError()
3564  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3565  << ").";
3566  }
3567  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3568  opIndexingMap.getNumResults() != 2) {
3569  return batchVariantMatmulOp->emitOpError()
3570  << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3571  << ").";
3572  }
3573 
3574  auto areValidOutputResultDim = [&](AffineMap outputMap) {
3575  return isa<BatchMatmulOp>(batchVariantMatmulOp)
3576  ? outputMap.getResult(0).isFunctionOfDim(0) &&
3577  outputMap.getResult(1).isFunctionOfDim(1) &&
3578  outputMap.getResult(2).isFunctionOfDim(2)
3579  : outputMap.getResult(0).isFunctionOfDim(1) &&
3580  outputMap.getResult(1).isFunctionOfDim(2);
3581  };
3582 
3583  if (!areValidOutputResultDim(opIndexingMap)) {
3584  return batchVariantMatmulOp->emitOpError()
3585  << "Invalid output map result dimension.";
3586  }
3587 
3588  return success();
3589 }
3590 
3591 /// Verifies the broadcast and transpose semantics specified by the explicit
3592 /// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3593 /// specified by opIndex.
3594 template <typename OpTy>
3595 static LogicalResult
3596 verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3597  unsigned opIndex) {
3598  SmallVector<AffineMap, 3> opIndexingMaps =
3599  batchVariantMatmulOp.getIndexingMapsArray();
3600  SmallVector<AffineMap, 3> defaultIndexingMaps =
3601  batchVariantMatmulOp.getDefaultIndexingMaps(
3602  batchVariantMatmulOp->getContext());
3603 
3604  if (opIndexingMaps.size() != 3)
3605  return batchVariantMatmulOp->emitOpError()
3606  << "Indexing_map attribute must have 3 affine maps.";
3607 
3608  auto opIndexingMap = opIndexingMaps[opIndex];
3609  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3610 
3611  if (opIndex == 2 &&
3612  failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3613  return failure();
3614 
3615  if (opIndex != 2 &&
3616  failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3617  defaultIndexingMap, opIndex == 0)))
3618  return failure();
3619 
3620  return success();
3621 }
3622 
3623 namespace mlir {
3624 namespace linalg {
3625 
3626 //===----------------------------------------------------------------------===//
3627 // MatMulOp
3628 //===----------------------------------------------------------------------===//
3629 
3630 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3631 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3632  AffineExpr d0, d1, d2;
3633  SmallVector<AffineMap> indexingMaps;
3634  bindDims(context, d0, d1, d2);
3635  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3636  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3637  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3638  return indexingMaps;
3639 }
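// In affine notation the maps above are
//   A: (d0, d1, d2) -> (d0, d2), B: (d0, d1, d2) -> (d2, d1),
//   C: (d0, d1, d2) -> (d0, d1),
// i.e. the classic C(d0, d1) += A(d0, d2) * B(d2, d1) access pattern.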
3640 
3641 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3642  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3643  utils::IteratorType::parallel,
3644  utils::IteratorType::reduction};
3645 }
3646 
3647 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3648 
3649 std::string MatmulOp::getLibraryCallName() {
3650  return generateLibraryCallName(getOperation());
3651 }
3652 
3653 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3654 
3655 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3656 /// the user-defined indexing maps are not equal to the default maps.
3657 bool MatmulOp::hasUserDefinedMaps() {
3658  SmallVector<AffineMap, 3> defaultMaps =
3659  getDefaultIndexingMaps(this->getContext());
3660  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3661  return defaultMaps != explicitMaps;
3662 }
3663 
3664 /// Implements the block region builder for the MatmulOp. This is called by
3665 /// 'fillStructuredOpRegion'.
3666 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3667  ArrayRef<NamedAttribute> attrs) {
3668  assert(3 > 0 && block.getNumArguments() == 3 &&
3669  "MatmulOp regionBuilder expects 3 (>=0) args");
3670  RegionBuilderHelper helper(b, block);
3671  SmallVector<Value> yields;
3672 
3673  TypeFn castVal = TypeFn::cast_signed;
3674  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3675  return attr.getName() == "cast";
3676  });
3677  if (castIter != attrs.end()) {
3678  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3679  castVal = attr.getValue();
3680  }
3681 
3682  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3683  block.getArgument(0));
3684  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3685  block.getArgument(1));
3686  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3687  Value value4 =
3688  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3689  yields.push_back(value4);
3690  helper.yieldOutputs(yields);
3691 }
3692 
3693 /// Returns true if the given `bcastMap` is a valid broadcast map. A valid
3694 /// broadcast map must include the K dimension.
3695 /// TODO: Strict inclusion of K dimension in the broadcast map is not
3696 /// necessary for both input matrices simultaneously. We can relax this
3697 /// condition to have K dimension for one input matrix map and infer the K
3698 /// dimension for other input matrix map from the one already having K
3699 /// dimension.
3700 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3701  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3702  AffineExpr expr = bcastMap.getResult(0);
3703  // Invalid map if the common dimension of matmul not found.
3704  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3705 }
3706 
3707 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3708  if (parser.parseOptionalKeyword("indexing_maps"))
3709  return ArrayAttr{
3710  nullptr}; // Success in case indexing_maps was not provided.
3711 
3712  ArrayAttr arrayAttr;
3713  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3714  return failure();
3715 
3716  if (llvm::any_of(arrayAttr,
3717  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3718  return parser.emitError(parser.getCurrentLocation())
3719  << "element of indexing_maps array is not an affine_map";
3720 
3721  return arrayAttr;
3722 }
3723 
3724 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3725  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3726  if (failed(indexingMapsAttr))
3727  return failure();
3728 
3729  if (*indexingMapsAttr == nullptr) {
3730  auto indexingMapAttrs = llvm::map_to_vector(
3731  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3732  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3733  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3734  }
3735 
3736  result.addAttribute("indexing_maps", *indexingMapsAttr);
3737  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3738  MatmulOp::getRegionBuilder());
3739 }
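// Example of the custom assembly handled by the parser above (a sketch; an
// explicit indexing_maps attribute requesting a transposed LHS):
//   linalg.matmul indexing_maps = [
//       affine_map<(d0, d1, d2) -> (d2, d0)>,  // transposed A
//       affine_map<(d0, d1, d2) -> (d2, d1)>,
//       affine_map<(d0, d1, d2) -> (d0, d1)>]
//     ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//     outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>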
3740 
3741 void MatmulOp::print(OpAsmPrinter &p) {
3742  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3743  MatmulOp::getDefaultIndexingMaps(getContext()),
3744  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3745  if (!llvm::equal(getIndexingMaps(), indexingMaps))
3746  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3747 
3748  std::array<StringRef, 3> elidedAttrs = {
3749  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3750  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3751  elidedAttrs);
3752 }
3753 
3754 /// Verifies the user-defined indexing maps.
3755 LogicalResult MatmulOp::verify() {
3756  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3757  if (!hasUserDefinedMaps())
3758  return success();
3759 
3760  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3761  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3762  return failure();
3763  }
3764  return success();
3765 }
3766 
3767 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3768  return memref::foldMemRefCast(*this);
3769 }
3770 
3771 void MatmulOp::getEffects(
3773  &effects) {
3774  if (hasPureTensorSemantics())
3775  return;
3776  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3777 }
3778 
3779 Speculation::Speculatability MatmulOp::getSpeculatability() {
3780  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3781 }
3782 
3783 //===----------------------------------------------------------------------===//
3784 // ContractOp
3785 //===----------------------------------------------------------------------===//
3786 
3787 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3788  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3789  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3790  // domains are all the same, and each implements a projected permutation.
3791  // Each iteration space dim must occur for at least one operand and either
3792  // takes part in a contraction/reduction or else has parallel iteration type.
3793  // We have that a dim is a contraction/reduction dim if and only if the dim
3794  // occurs for the output operand. We use this fact for fast inference:
3795  // NB: In case we allow dims to occur solely for one input, the above still
3796  // holds: per the einsum semantics, these are reduction dims as well.
3797  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3798  for (auto result : outAffineMap.getResults()) {
3799  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3800  assert(dimExpr && "affine_map is a projected permutation");
3801  dimsInOutput[dimExpr.getPosition()] = true;
3802  }
3803 
3804  SmallVector<utils::IteratorType> iteratorTypes;
3805  for (auto dimOccursInOutput : dimsInOutput)
3806  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3807  : utils::IteratorType::reduction);
3808 
3809  return iteratorTypes;
3810 }
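// For example, for matmul-like maps {(d0, d2), (d2, d1), (d0, d1)} only d0
// and d1 occur in the output map, so the inferred iterator types are
// [parallel, parallel, reduction]: d2 is the contraction dim.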
3811 
3812 unsigned ContractOp::getNumRegionArgs() { return 3; }
3813 
3814 /// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
3815 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3816  ArrayRef<NamedAttribute> attrs) {
3817  assert(block.getNumArguments() == 3 &&
3818  "ContractOp regionBuilder expects 3 args");
3819  RegionBuilderHelper helper(b, block);
3820 
3821  TypeFn castSignedness = TypeFn::cast_signed;
3822  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3823  return attr.getName() == "cast";
3824  });
3825  if (castIter != attrs.end()) {
3826  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3827  castSignedness = attr.getValue();
3828  }
3829 
3830  // TODO: Support fields with operators besides mult & add.
3831  Type outType = block.getArgument(2).getType();
3832  Value lhsAtOutType =
3833  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3834  Value rhsAtOutType =
3835  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3836  Value productAtOutType =
3837  helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3838  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3839  productAtOutType);
3840  helper.yieldOutputs({result});
3841 }
3842 
3843 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3844  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3845  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3846  return parser.emitError(parser.getCurrentLocation(),
3847  "expected 'indexing_maps' attribute");
3848  result.addAttribute("indexing_maps", *indexingMapsAttr);
3849 
3850  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3851  regionBuilder);
3852 }
3853 
3855  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3857  p, getOperation(), getInputs(), getOutputs(),
3858  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3859 }
3860 
3861 LogicalResult ContractOp::verify() {
3862  int iterationSpaceDims = -1;
3863  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3864  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3865  // access an input operand (so occurrence count can be at most 2) and
3866  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3867  SmallVector<size_t> inOccurrences;
3868  SmallVector<size_t> outOccurrences;
3869 
3870  // A helper so that for each operand's affine_map and type we check that ...
3871  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3872  bool isInput) -> LogicalResult {
3873  // ... the affine_map is a projected permutation;
3874  if (!affineMap.isProjectedPermutation())
3875  return emitError("provided affine_map is not a projected permutation");
3876 
3877  // ... the rank of the affine_map's results and corresponding type match;
3878  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3879  if (affineMap.getNumResults() != shapedType.getRank())
3880  return emitError("ranks of shaped operand and results of corresponding "
3881  "affine_map differ");
3882  } else if (affineMap.getNumResults() != 0) {
3883  return emitError("affine_map specifies shaped access while operand has "
3884  "non-shaped type");
3885  }
3886 
3887  // ... the rank of the affine_map's domain is the same as those seen prior;
3888  if (iterationSpaceDims == -1) {
3889  iterationSpaceDims = affineMap.getNumDims();
3890  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3891  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3892  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3893  return emitError("iteration spaces of provided affine_maps differ");
3894  }
3895 
3896  // ... update counts of dims used to access either an input or the output.
3897  for (AffineExpr affineExpr : affineMap.getResults()) {
3898  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3899  if (!affineDimExpr)
3900  llvm_unreachable("affine_map is a projected permutation");
3901 
3902  if (isInput)
3903  inOccurrences[affineDimExpr.getPosition()] += 1;
3904  else
3905  outOccurrences[affineDimExpr.getPosition()] += 1;
3906  }
3907 
3908  return success();
3909  };
3910 
3911  for (auto &&[affineMap, operandType, isInput] :
3912  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3913  SmallVector<bool>{true, true, false})) {
3914  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3915  return failure(); // NB: checkAffineMapAndType will emit relevant error.
3916  }
3917 
3918  bool hasContractingDim = false;
3919  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3920  size_t inOccCount = inOccurrences[dimIndex];
3921  size_t outOccCount = outOccurrences[dimIndex];
3922 
3923  // We have a contracting dim if and only if ...
3924  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3925 
3926  if (inOccCount == 0 && outOccCount == 0)
3927  return emitError() << "iteration space dim at index " << dimIndex
3928  << " not used to access any operand";
3929 
3930  // NB: We disallow a dim which occurs for only one input operand and not
3931  // for the output. In terms of einsum semantics such dims have a
3932  // sensible meaning - namely an additional reduction per each such dim.
3933  // By contrast, the ContractionOpInterface does not know about this
3934  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3935  // while vector.contract's verifier accepts dims of this kind many of
3936  // its lowerings give up on encountering these dims.
3937  // TODO: Remove following once we have comprehensive support for input-only
3938  // reduction dims, at both the linalg- and vector-dialect levels.
3939  if (inOccCount == 1 && outOccCount != 1)
3940  return emitError()
3941  << "iteration space dim at index " << dimIndex
3942  << " is neither a contracting dim nor of parallel iteration type";
3943  }
3944 
3945  if (!hasContractingDim)
3946  return emitError("'indexing_maps' do not specify a contracting dimension");
3947 
3948  return success();
3949 }
3950 
3951 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3952  return memref::foldMemRefCast(*this);
3953 }
3954 
3955 void ContractOp::getEffects(
3957  &effects) {
3958  if (hasPureTensorSemantics())
3959  return;
3960  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3961 }
3962 
3963 Speculation::Speculatability ContractOp::getSpeculatability() {
3964  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3965 }
3966 
3967 //===----------------------------------------------------------------------===//
3968 // Implementation of BatchMatmulOp
3969 //===----------------------------------------------------------------------===//
3970 SmallVector<AffineMap>
3971 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3972  AffineExpr d0, d1, d2, d3;
3973  SmallVector<AffineMap> indexingMaps;
3974  bindDims(context, d0, d1, d2, d3);
3975  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3976  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3977  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3978  return indexingMaps;
3979 }
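// In affine notation the maps above are
//   A: (d0, d1, d2, d3) -> (d0, d1, d3), B: (d0, d1, d2, d3) -> (d0, d3, d2),
//   C: (d0, d1, d2, d3) -> (d0, d1, d2),
// i.e. a matmul with shared leading batch dim d0 and contraction dim d3.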
3980 
3981 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3982  return SmallVector<utils::IteratorType>{
3983  utils::IteratorType::parallel, utils::IteratorType::parallel,
3984  utils::IteratorType::parallel, utils::IteratorType::reduction};
3985 }
3986 
3987 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3988 
3989 std::string BatchMatmulOp::getLibraryCallName() {
3990  return generateLibraryCallName(getOperation());
3991 }
3992 
3993 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3994 /// the user-defined indexing maps are not equal to the default maps.
3995 bool BatchMatmulOp::hasUserDefinedMaps() {
3996  SmallVector<AffineMap, 3> defaultMaps =
3997  getDefaultIndexingMaps(this->getContext());
3998  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3999  return defaultMaps != explicitMaps;
4000 }
4001 
4002 /// Returns true if the given `bcastMap` is a valid broadcast map. A valid
4003 /// broadcast map must include the K dimension.
4004 /// TODO: Strict inclusion of K dimension in the broadcast map is not
4005 /// necessary for both input matrices simultaneously. We can relax this
4006 /// condition to have K dimension for one input matrix map and infer the K
4007 /// dimension for other input matrix map from the one already having K
4008 /// dimension.
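/// For example (maps the checks below accept, with
/// (d0, d1, d2, d3) = (batch, m, n, k)):
///   LHS: (d3) or (d0, d3) or (d1, d3)
///   RHS: (d3) or (d0, d3) or (d3, d2)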
4009 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4010  assert(bcastMap.getNumResults() < 3 &&
4011  "Expected less than 3 result dim expr.");
4012  bool isValid = false;
4013  enum Indices { batchPos, mPos, nPos, kPos };
4014  if (bcastMap.getNumResults() == 1) {
4015  AffineExpr expr = bcastMap.getResult(0);
4016  isValid = expr.isFunctionOfDim(kPos);
4017  } else if (bcastMap.getNumResults() == 2) {
4018  AffineExpr expr0 = bcastMap.getResult(0);
4019  AffineExpr expr1 = bcastMap.getResult(1);
4020  isValid =
4021  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4022  expr0.isFunctionOfDim(mPos)) &&
4023  expr1.isFunctionOfDim(kPos))
4024  : ((expr0.isFunctionOfDim(batchPos) &&
4025  expr1.isFunctionOfDim(kPos)) ||
4026  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4027  }
4028  return isValid;
4029 }
4030 
4031 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4032  ArrayRef<NamedAttribute> attrs) {
4033  assert(block.getNumArguments() == 3 &&
4034  "BatchMatmulOp regionBuilder expects 3 (>=0) args");
4035  RegionBuilderHelper helper(b, block);
4036  SmallVector<Value> yields;
4037 
4038  TypeFn castVal = TypeFn::cast_signed;
4039  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4040  return attr.getName() == "cast";
4041  });
4042  if (castIter != attrs.end()) {
4043  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4044  castVal = attr.getValue();
4045  }
4046 
4047  auto toType = block.getArgument(2).getType();
4048  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4049  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4050  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4051  Value addVal =
4052  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4053  yields.push_back(addVal);
4054  helper.yieldOutputs(yields);
4055 }
4056 
4057 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4058  SmallVector<Attribute, 3> indexingMapsAttr;
4059  Attribute mapAttr;
4060  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4061  if (parser.parseEqual())
4062  return failure();
4063 
4064  if (parser.parseLSquare())
4065  return failure();
4066 
4067  do {
4068  if (parser.parseAttribute(mapAttr))
4069  return failure();
4070  if (!isa<AffineMapAttr>(mapAttr)) {
4071  return parser.emitError(parser.getCurrentLocation(),
4072  "expected affine map attribute");
4073  }
4074  indexingMapsAttr.push_back(mapAttr);
4075 
4076  if (parser.parseOptionalComma())
4077  break;
4078  } while (true);
4079 
4080  if (parser.parseRSquare())
4081  return failure();
4082  }
4083  // Initialize indexingMaps, if not supplied explicitly.
4084  if (indexingMapsAttr.empty()) {
4085  indexingMapsAttr = llvm::map_to_vector(
4086  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4087  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4088  }
4089  result.addAttribute("indexing_maps",
4090  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4091 
4092  return ::parseNamedStructuredOp(parser, result,
4093  BatchMatmulOp::getNumRegionArgs(),
4094  BatchMatmulOp::getRegionBuilder());
4095 }
4096 
4098  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4099  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4100  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4101  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4102  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4103 
4104  std::array<StringRef, 3> elidedAttrs = {
4105  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4106  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4107  elidedAttrs);
4108 }
4109 
4110 /// Verifies the user-defined indexing maps.
4111 LogicalResult BatchMatmulOp::verify() {
4112  // Verification of pure batch_matmul is handled by
4113  // verifyStructuredOpInterface().
4114  if (!hasUserDefinedMaps())
4115  return success();
4116 
4117  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4118  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4119  return failure();
4120  }
4121  return success();
4122 }
4123 
4124 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4125  SmallVectorImpl<OpFoldResult> &) {
4126  return memref::foldMemRefCast(*this);
4127 }
4128 
4129 void BatchMatmulOp::getEffects(
4131  &effects) {
4132  if (hasPureTensorSemantics())
4133  return;
4134  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4135 }
4136 
4137 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4138  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4139 }
4140 
4141 //===----------------------------------------------------------------------===//
4142 // ElementwiseOp
4143 //===----------------------------------------------------------------------===//
4144 //
4145 namespace {
4146 struct ArityGroupAndKind {
4147  // The arity group: Unary, Binary, Ternary, ...
4148  ElementwiseArityGroup arityGroup;
4149 
4150  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4151  union Kind {
4152  UnaryFn unaryFn;
4153  BinaryFn binaryFn;
4154  TernaryFn ternaryFn;
4155  } kind;
4156 };
4157 
4158 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4159  return static_cast<unsigned>(arityGroup);
4160 }
4161 } // namespace
4162 
4163 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4164  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4165  constexpr int lastBinary =
4166  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4167  constexpr int lastTernary =
4168  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4169 
4170  int val = static_cast<int>(kind);
4171  ArityGroupAndKind result;
4172 
4173  if (val < lastUnary) {
4174  result.arityGroup = ElementwiseArityGroup::Unary;
4175  result.kind.unaryFn = static_cast<UnaryFn>(val);
4176  return result;
4177  }
4178  if (val < lastBinary) {
4179  result.arityGroup = ElementwiseArityGroup::Binary;
4180  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4181  return result;
4182  }
4183  if (val >= lastTernary) {
4184  llvm_unreachable("unhandled ElementwiseFn");
4185  }
4186  result.arityGroup = ElementwiseArityGroup::Ternary;
4187  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4188  return result;
4189 }
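// For example (illustrative enum values only): a kind whose value falls in
// [lastUnary, lastBinary) is binary, and its BinaryFn is recovered by
// subtracting the unary bias; with a hypothetical lastUnary = 5, an
// ElementwiseKind value of 7 decodes to BinaryFn(2).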
4190 
4191 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4192  auto rank = getResultRank();
4193  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4194 }
4195 
4196 SmallVector<AffineMap>
4197 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4198  MLIRContext *context) {
4199  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4200  return SmallVector<AffineMap>(numMaps, map);
4201 }
4202 
4203 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4204  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4205  Attribute attr;
4206  mlir::linalg::ElementwiseKind elemwiseKindVal;
4207  if (parser.parseKeyword("kind") || parser.parseEqual())
4208  return failure();
4209 
4210  if (succeeded(parser.parseAttribute(attr))) {
4211  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4212  if (!elemwiseKindAttr)
4213  return parser.emitError(parser.getCurrentLocation(),
4214  "expected ElementwiseKind attribute");
4215  elemwiseKindVal = elemwiseKindAttr.getValue();
4216  } else {
4217  return parser.emitError(parser.getCurrentLocation(),
4218  "expected operation 'kind' attribute");
4219  }
4220  result.addAttribute(
4221  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4222 
4223  // Parse optional `indexing_maps`
4224  SmallVector<Attribute, 3> indexingMapsAttr;
4225  Attribute mapAttr;
4226  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4227  if (parser.parseEqual())
4228  return failure();
4229  if (parser.parseLSquare())
4230  return failure();
4231  do {
4232  if (parser.parseAttribute(mapAttr))
4233  return failure();
4234  if (!isa<AffineMapAttr>(mapAttr))
4235  return parser.emitError(parser.getCurrentLocation(),
4236  "expected affine map attribute");
4237  indexingMapsAttr.push_back(mapAttr);
4238  if (parser.parseOptionalComma())
4239  break;
4240  } while (true);
4241  if (parser.parseRSquare())
4242  return failure();
4243  }
4244  // At this stage of parsing, the only way to infer the number of region
4245  // args is through the op kind, as the input/output tensors are not parsed yet.
4246  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4247  int numRegionArgs =
4248  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4249  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4250  ElementwiseOp::getRegionBuilder())) {
4251  return parser.emitError(parser.getCurrentLocation(),
4252  "unable to parse elemwise op");
4253  }
4254 
4255  // Initialize indexingMaps, if not supplied explicitly.
4256  if (indexingMapsAttr.empty()) {
4257  // We need to infer the numDims of the indexing maps from the output
4258  // type which is already parsed by now.
4259  auto resultType = result.operands[result.operands.size() - 1].getType();
4260  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4261  if (!shapedType)
4262  return parser.emitError(parser.getCurrentLocation(),
4263  "return type needs to be shaped type");
4264  auto numDims = shapedType.getRank();
4265  indexingMapsAttr = llvm::map_to_vector(
4266  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4267  parser.getContext()),
4268  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4269  }
4270 
4271  result.addAttribute("indexing_maps",
4272  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4273  return success();
4274 }
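// Example of the custom assembly handled by the parser above (a sketch; the
// kind attribute spelling follows the comment in parse(), and the identity
// indexing maps are inferred from the output rank when elided):
//   linalg.elementwise kind=#linalg.elemwise_kind<add>
//     ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
//     outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>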
4275 
4277  p << " kind=";
4278  p.printAttribute(getKindAttr());
4279  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4280  "indexing_maps"};
4281  unsigned arity =
4282  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4283  unsigned numDims = getResultRank();
4284 
4285  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4286  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4287  getContext()),
4288  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4289 
4290  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4291  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4292 
4293  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4294  elidedAttrs);
4295 }
4296 
4297 LogicalResult ElementwiseOp::verify() {
4298  // All necessary checks are done either by
4299  // - EnumAttr (e.g. unknown operation kind)
4300  // - verifyStructuredOpInterface (incorrect map, sizes).
4301  return success();
4302 }
4303 
4304 /// Implements the block region builder for the ElementwiseOp. This is called by
4305 /// 'fillStructuredOpRegion'.
4306 void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4307  ArrayRef<NamedAttribute> attrs) {
4308  ElementwiseKind elemwiseKind;
4309  for (auto attr : attrs) {
4310  if (attr.getName() == b.getStringAttr("kind")) {
4311  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4312  assert(kindAttr && "op kind attribute incorrectly set");
4313  elemwiseKind = kindAttr.getValue();
4314  break;
4315  }
4316  }
4317 
4318  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4319  auto arityGroup = groupAndKind.arityGroup;
4320  auto kind = groupAndKind.kind;
4321  assert(block.getNumArguments() ==
4322  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4323  && "Elementwise regionBuilder number of block args mismatch");
4324 
4325  RegionBuilderHelper helper(b, block);
4326  SmallVector<Value> yields;
4327  Value result;
4328 
4329  if (arityGroup == ElementwiseArityGroup::Unary) {
4330  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4331 
4332  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4333  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4334  block.getArgument(1));
4335 
4336  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4337  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4338  block.getArgument(1), block.getArgument(2));
4339 
4340  } else {
4341  assert(false && "found unhandled category in elemwise");
4342  }
4343 
4344  yields.push_back(result);
4345  helper.yieldOutputs(yields);
4346 }
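// For a binary kind such as `add` on f32 operands, the region built above is
// equivalent to this sketch (names are illustrative only):
//
// ```mlir
// ^bb0(%lhs: f32, %rhs: f32, %out: f32):
//   %sum = arith.addf %lhs, %rhs : f32
//   linalg.yield %sum : f32
// ```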
4347 
4348 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4349  SmallVectorImpl<OpFoldResult> &) {
4350  return memref::foldMemRefCast(*this);
4351 }
4352 
4353 void ElementwiseOp::getEffects(
4354  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4355  &effects) {
4356  if (hasPureTensorSemantics())
4357  return;
4358  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4359 }
4360 
4361 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4362  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4363 }
4364 
4365 //===----------------------------------------------------------------------===//
4366 // PackOp/UnPackOp Common
4367 //===----------------------------------------------------------------------===//
4368 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4369 // updated mixed-tile-sizes attribute. A tile size is updated only
4370 // when:
4371 // * a dim from newPackedTy is static, and
4372 // * the corresponding size from mixedTiles is still dynamic.
4373 // Otherwise, the original tile size is preserved.
4374 // Note - packed-type-dim and mixed-tile-size should always match!
4375 static SmallVector<OpFoldResult>
4376 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4377  SmallVector<OpFoldResult> mixedTiles) {
4378  SmallVector<OpFoldResult> newMixedTileSizes;
4379  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4380  .getShape()
4381  .take_back(mixedTiles.size()),
4382  mixedTiles)) {
4383  int64_t shape = std::get<0>(it);
4384  if (shape == ShapedType::kDynamic) {
4385  newMixedTileSizes.push_back(std::get<1>(it));
4386  continue;
4387  }
4388 
4389  // If the current result dim is static, update the dynamic mixed-size
4390  // (provided the original value is dynamic).
4391  OpFoldResult tile = std::get<1>(it);
4392  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4393  // Already a constant
4394  newMixedTileSizes.push_back(tile);
4395  } else {
4396  assert(getConstantIntValue(tile).value() == shape &&
4397  "tile size and dim size don't match!");
4398  newMixedTileSizes.push_back(
4399  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4400  }
4401  }
4402 
4403  return newMixedTileSizes;
4404 }
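// For example (hypothetical values): with newPackedTy = tensor<?x?x16x32xf32>
// and mixedTiles = [%tile, 32], the result is [16, 32] -- the dynamic %tile
// (asserted to equal the now-static trailing dim) is replaced by the constant
// 16, while the already-constant 32 is preserved.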
4405 
4406 template <typename OpTy>
4407 static LogicalResult
4408 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4409  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4410  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4411  "applies to only pack or unpack operations");
4412  int64_t destRank = op.getDestRank();
4413  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4414  reifiedReturnShapes[0] =
4415  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4416  return success();
4417 }
4418 
4419 template <typename OpTy>
4420 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4421  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4422  "applies to only pack or unpack operations");
4423  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4424  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4425  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4426  assert(tiles.size() == dimsToTile.size() &&
4427  "tiles must match indices of dimension to block");
4428  // Bind each dimension in `dimsToTile` to its tile factor.
4429  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4430  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4431  return dimAndTileMapping;
4432 }
4433 
4434 template <typename OpTy>
4435 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4436  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4437  "applies to only pack or unpack operations");
4438  Builder builder(op);
4439  SmallVector<OpFoldResult> mixedInnerTiles;
4440  unsigned dynamicValIndex = 0;
4441  for (int64_t staticTile : op.getStaticInnerTiles()) {
4442  if (!ShapedType::isDynamic(staticTile))
4443  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4444  else
4445  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4446  }
4447  return mixedInnerTiles;
4448 }
4449 
4450 template <typename OpTy>
4451 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4452  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4453  "applies to only pack or unpack operations");
4454  SmallVector<Value> dynamicTiles;
4455  SmallVector<int64_t> staticTiles;
4456  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4457  return staticTiles;
4458 }
4459 
4460 /// Returns true if `dimsPos` is invalid. It is invalid when:
4461 /// a) it contains duplicate entries,
4462 /// b) at least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= dimPos < rank), or
4463 /// c) `dimsPos` contains more entries than `rank`.
4464 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4465  size_t rank) {
4466  size_t dimsPosSize = dimsPos.size();
4467  if (dimsPosSize > rank)
4468  return true;
4469  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4470  if (dimsPosSize != uniqued.size())
4471  return true;
4472  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4473  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4474  });
4475 }
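// For example, with rank = 3: [0, 2] is a valid specification, whereas
// [0, 0] (duplicate), [0, 3] (out of bounds) and [0, 1, 2, 0] (more entries
// than rank) are all invalid.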
4476 
4477 /// Returns true if each dimension of `sourceShape` is no larger than the
4478 /// corresponding dimension of `limitShape` (dynamic dims are always in bound).
4479 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4480  ArrayRef<int64_t> limitShape) {
4481  assert(
4482  sourceShape.size() == limitShape.size() &&
4483  "expected source shape rank, and limit of the shape to have same rank");
4484  return llvm::all_of(
4485  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4486  int64_t sourceExtent = std::get<0>(it);
4487  int64_t limit = std::get<1>(it);
4488  return ShapedType::isDynamic(sourceExtent) ||
4489  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4490  });
4491 }
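// For example, sourceShape = [4, ?] is in bound for limitShape = [4, 8],
// whereas sourceShape = [16, 8] is not (16 > 4).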
4492 
4493 template <typename OpTy>
4494 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4495  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4496  "applies to only pack or unpack operations");
4497  Operation *op = packOrUnPack.getOperation();
4498 
4499  // Return true if we have a zero-value tile.
4500  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4501  return llvm::any_of(tiles, isZeroInteger);
4502  };
4503 
4504  // Verify tiles. Do not allow zero tiles.
4505  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4506  if (hasZeros(mixedTiles))
4507  return op->emitError("invalid zero tile factor");
4508 
4509  // Verify inner_dims_pos and outer_dims_perm.
4510  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4511  ? packOrUnPack.getSourceType()
4512  : packOrUnPack.getDestType();
4513  size_t unpackedRank = unpackedType.getRank();
4514  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4515  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4516  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
4517  return op->emitError("invalid inner_dims_pos vector");
4518  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4519  return op->emitError("invalid outer_dims_perm vector");
4520  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4521  return op->emitError("outer_dims_perm must be a permutation or empty");
4522 
4523  // Tiling factors must be less than or equal to the input rank for pack (or
4524  // output rank for unpack), and must match the number of `inner_dims_pos`.
4525  if (mixedTiles.size() > unpackedRank) {
4526  return op->emitError("tiling factors must be less than or equal to the "
4527  "input rank for pack or output rank for unpack");
4528  }
4529  if (mixedTiles.size() != innerDimsPos.size()) {
4530  return op->emitError(
4531  "tiling factors must equal the number of dimensions to tile");
4532  }
4533 
4534  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4535  ? packOrUnPack.getDestType()
4536  : packOrUnPack.getSourceType();
4537  size_t packedRank = packedType.getRank();
4538  // Require output rank to match input rank + number of blocking factors.
4539  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4540  if (expectedPackedRank != packedRank) {
4541  return op->emitError(
4542  "packed rank != (unpacked rank + num tiling factors), got ")
4543  << packedRank << " != " << expectedPackedRank;
4544  }
4545 
4546  // Verify that the result shape is at least as large as the minimum shape
4547  // expected by the pack operation, and that the output shape
4548  // represents full tiles.
4549  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4550  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4551  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4552  return op->emitError("the shape of output is not large enough to hold the "
4553  "packed data. Expected at least ")
4554  << expectedPackedType << ", got " << packedType;
4555  }
4556  if (!llvm::all_of(
4557  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4558  mixedTiles),
4559  [](std::tuple<int64_t, OpFoldResult> it) {
4560  int64_t shape = std::get<0>(it);
4561  if (Attribute attr =
4562  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4563  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4564  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4565  return shape == staticTileSize;
4566  }
4567  return ShapedType::isDynamic(shape);
4568  })) {
4569  return op->emitError("mismatch in inner tile sizes specified and shape of "
4570  "tiled dimension in the packed type");
4571  }
4572  return success();
4573 }
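// As a concrete example, the following pack satisfies all of the checks
// above (illustrative shapes):
//
// ```mlir
// %pack = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %dst : tensor<128x256xf32> -> tensor<16x8x8x32xf32>
// ```
//
// Here packed rank (4) = unpacked rank (2) + number of tiles (2), and each
// trailing packed dim matches its static tile size.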
4574 
4575 namespace {
4576 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4577 /// various permutations to the op.
4578 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4579 // these. These may or may not become true foldings / canonicalizations
4580 // depending on how aggressive we want to be in automatically folding
4581 // transposes.
4582 struct PackOrUnPackTransposeResult {
4583  SmallVector<int64_t> innerDimsPos;
4584  SmallVector<OpFoldResult> innerTiles;
4585  SmallVector<int64_t> outerDimsPerm;
4586 };
4587 } // namespace
4588 
4589 template <typename OpTy>
4590 static PackOrUnPackTransposeResult
4591 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4592  ArrayRef<int64_t> innerPermutation,
4593  ArrayRef<int64_t> outerPermutation) {
4594  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4595  "applies to only pack or unpack operations");
4596  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4597  "some permutation must be non-empty");
4598  PackOrUnPackTransposeResult metadata;
4599  metadata.innerDimsPos =
4600  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4601  metadata.innerTiles =
4602  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4603  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4604  ? packOrUnPackOp.getSourceRank()
4605  : packOrUnPackOp.getDestRank();
4606  metadata.outerDimsPerm =
4607  packOrUnPackOp.getOuterDimsPerm().empty()
4608  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4609  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4610  if (!innerPermutation.empty()) {
4611  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4612  isPermutationVector(innerPermutation) &&
4613  "invalid inner permutation");
4614  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4615  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4616  }
4617  if (!outerPermutation.empty()) {
4618  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4619  isPermutationVector(outerPermutation) &&
4620  "invalid outer permutation");
4621  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4622  }
4623  return metadata;
4624 }
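// For example, starting from innerDimsPos = [0, 1] and innerTiles = [8, 32],
// an innerPermutation of [1, 0] yields innerDimsPos = [1, 0] and
// innerTiles = [32, 8]; outerPermutation reorders outerDimsPerm analogously.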
4625 
4626 //===----------------------------------------------------------------------===//
4627 // PackOp
4628 //===----------------------------------------------------------------------===//
4629 
4630 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4631  setNameFn(getResult(), "pack");
4632 }
4633 
4634 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4635  Value dest, ArrayRef<int64_t> innerDimsPos,
4636  ArrayRef<OpFoldResult> innerTiles,
4637  std::optional<Value> paddingValue,
4638  ArrayRef<int64_t> outerDimsPerm) {
4639  assert(innerDimsPos.size() == innerTiles.size() &&
4640  "number of tile sizes specified must match the specified number of "
4641  "original dimensions to be tiled");
4642  SmallVector<int64_t> staticTileSizes;
4643  SmallVector<Value> dynamicTileSizes;
4644  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4645  build(builder, state, dest.getType(), source, dest,
4646  paddingValue ? *paddingValue : nullptr,
4647  outerDimsPerm.empty() ? nullptr
4648  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4649  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4650  builder.getDenseI64ArrayAttr(staticTileSizes));
4651 }
4652 
4653 LogicalResult
4654 PackOp::reifyResultShapes(OpBuilder &builder,
4655  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4656  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4657 }
4658 
4659 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4660  return getDimAndTileMappingImpl(*this);
4661 }
4662 
4663 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4664  return getMixedTilesImpl(*this);
4665 }
4666 
4667 SmallVector<int64_t> PackOp::getStaticTiles() {
4668  return getStaticTilesImpl(*this);
4669 }
4670 
4671 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4672  ShapedType inputType = getSourceType();
4673  int64_t inputRank = inputType.getRank();
4674  return getDestType().getShape().take_front(inputRank);
4675 }
4676 
4677 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4678  auto innerDimsPos = getInnerDimsPos();
4679  auto packedShape = getDestType().getShape();
4680  SmallVector<int64_t> res;
4681 
4682  for (auto index : innerDimsPos)
4683  res.push_back(packedShape[index]);
4684 
4685  return res;
4686 }
4687 
4688 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4689  ArrayRef<int64_t> innerDimsPos,
4690  ArrayRef<int64_t> outputShape,
4691  ArrayRef<int64_t> outerDimsPerm,
4692  ArrayRef<OpFoldResult> innerTiles) {
4693  SmallVector<int64_t> outputTileSizes(
4694  outputShape.take_front(inputShape.size()));
4695  if (!outerDimsPerm.empty()) {
4696  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4697  "expected output and outer_dims_perm to have same size");
4698  applyPermutationToVector(outputTileSizes,
4699  invertPermutationVector(outerDimsPerm));
4700  }
4701  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4702  if (ShapedType::isDynamic(inputShape[pos]))
4703  continue;
4704  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4705 
4706  if (!constantTile) {
4707  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4708  (inputShape[pos] % outputTileSizes[pos] != 0))
4709  return true;
4710  } else if (inputShape[pos] % (*constantTile) != 0) {
4711  return true;
4712  }
4713  }
4714  return false;
4715 }
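// For example, packing a tensor<17x?xf32> along dim 0 with a constant inner
// tile of 8 requires a padding value (17 % 8 != 0), whereas a tensor<16x?xf32>
// does not; dynamic input dims are skipped by the check above.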
4716 
4717 LogicalResult PackOp::verify() {
4718  if (failed(commonVerifierPackAndUnPackOp(*this)))
4719  return failure();
4720 
4721  // Verify padding value, and bail out if the tile does not divide the
4722  // dimension fully. In the case of dynamic tile factors or dimensions, having
4723  // a partial tile is undefined behavior.
4724  auto paddingValue = getPaddingValue();
4725  if (paddingValue &&
4726  paddingValue.getType() != getSourceType().getElementType()) {
4727  return emitOpError("expected padding_value has ")
4728  << getSourceType().getElementType()
4729  << " but got: " << paddingValue.getType();
4730  }
4731 
4732  if (!paddingValue &&
4733  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4734  getDestType().getShape(), getOuterDimsPerm(),
4735  getMixedTiles())) {
4736  return emitOpError(
4737  "invalid tile factor or output size provided. Only full tiles are "
4738  "supported when padding_value is not set");
4739  }
4740  return success();
4741 }
4742 
4743 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4744 /// Value's to kDynamic, even if they are arith.constant values.
4745 static SmallVector<int64_t>
4746 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4747  SmallVector<int64_t> result;
4748  for (auto o : ofrs) {
4749  // Have to do this first, as getConstantIntValue special-cases constants.
4750  if (llvm::dyn_cast_if_present<Value>(o))
4751  result.push_back(ShapedType::kDynamic);
4752  else
4753  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4754  }
4755  return result;
4756 }
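// For example, ofrs = [%v, 8] maps to [kDynamic, 8], even when %v happens to
// be produced by an arith.constant; only genuine Attributes stay static.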
4757 
4758 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4759 /// the packed type. Having a shared helper helps implement these two methods in
4760 /// a way that ensures that they agree on which dimensions are dynamic.
4761 static SmallVector<int64_t> getPackOpResultTypeShape(
4762  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4763  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4764  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4765  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4766  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4767  continue;
4768  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4769  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4770  continue;
4771  }
4772  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4773  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4774  }
4775 
4776  // Swap tile loops if outer_dims_perm is available.
4777  if (!outerDimsPerm.empty())
4778  applyPermutationToVector(resultShape, outerDimsPerm);
4779 
4780  // Append the inner tile dimensions.
4781  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4782  return resultShape;
4783 }
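// For example, sourceShape = [128, 256] with innerTileSizes = [8, 32],
// innerDimsPos = [0, 1] and an empty outerDimsPerm yields
// [16, 8, 8, 32], i.e. [128/8, 256/32] followed by the tile sizes.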
4784 
4785 SmallVector<OpFoldResult> PackOp::getResultShape(
4786  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4787  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4788  ArrayRef<int64_t> outerDimsPerm) {
4789  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4790 
4791  AffineExpr s0, s1;
4792  bindSymbols(builder.getContext(), s0, s1);
4793  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4794  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4795  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4796  builder, loc, ceilDivExpr,
4797  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4798  }
4799  if (!outerDimsPerm.empty())
4800  applyPermutationToVector(resultDims, outerDimsPerm);
4801  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4802 
4803  SmallVector<int64_t> resultTypeShape =
4804  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(resultDims),
4805  asShapeWithAnyValueAsDynamic(innerTileSizes),
4806  innerDimsPos, outerDimsPerm);
4807 
4808  // Fix-up `resultDims` to ensure that they are Value's if and only if the
4809  // result type shape says it's a dynamic dim. This is needed as callers may
4810  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4811  // dynamic dims returned by that.
4812  for (unsigned i = 0; i < resultDims.size(); ++i) {
4813  if (!ShapedType::isDynamic(resultTypeShape[i]))
4814  continue;
4815  resultDims[i] =
4816  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4817  }
4818 
4819  return resultDims;
4820 }
4821 
4822 /// Get the expected packed type based on source type, tile factors, position of
4823 /// the inner tiles and permutation of the outer tiled loop.
4824 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4825  ArrayRef<int64_t> innerTileSizes,
4826  ArrayRef<int64_t> innerDimsPos,
4827  ArrayRef<int64_t> outerDimsPerm) {
4828  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4829  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4830  return RankedTensorType::get(resultShape, sourceType.getElementType());
4831 }
4832 
4833 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4834  ArrayRef<OpFoldResult> innerTileSizes,
4835  ArrayRef<int64_t> innerDimsPos,
4836  ArrayRef<int64_t> outerDimsPerm) {
4837  AffineExpr dim0, dim1;
4838  bindDims(b.getContext(), dim0, dim1);
4839  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4840  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4841  {v1, v2});
4842  };
4843 
4844  SmallVector<OpFoldResult> mixedSizes;
4845  for (auto [index, value] : llvm::enumerate(
4846  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4847  if (ShapedType::isDynamic(value))
4848  mixedSizes.push_back(
4849  b.create<tensor::DimOp>(loc, source, index).getResult());
4850  else
4851  mixedSizes.push_back(b.getIndexAttr(value));
4852  }
4853  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4854  int64_t dimPos = std::get<0>(it);
4855  OpFoldResult tileSize = std::get<1>(it);
4856  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4857  }
4858  if (!outerDimsPerm.empty())
4859  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4860 
4861  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4862  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4863  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4864 }
4865 
4866 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4867  ArrayRef<int64_t> innerPermutation,
4868  ArrayRef<int64_t> outerPermutation) {
4869  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4870  *this, innerPermutation, outerPermutation);
4871  Value transposedDest =
4872  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4873  metadata.innerDimsPos, metadata.outerDimsPerm);
4874  return b.create<PackOp>(loc, getSource(), transposedDest,
4875  metadata.innerDimsPos, metadata.innerTiles,
4876  getPaddingValue(), metadata.outerDimsPerm);
4877 }
4878 
4879 /// Returns true if the tiles and the tiled dims are constant.
4880 template <typename OpTy>
4881 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4882  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4883  "applies to only pack or unpack operations");
4884  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4885  ? op.getDestType()
4886  : op.getSourceType();
4887  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4888  for (auto [dimDest, tile] : llvm::zip(
4889  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4890  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4891  if (!constTileSize || ShapedType::isDynamic(dimDest))
4892  return false;
4893  }
4894  return true;
4895 }
4896 
4897 Speculation::Speculatability PackOp::getSpeculatability() {
4898  if (getPaddingValue())
4899  return Speculation::Speculatable;
4900 
4901  // The verifier already rejects operations where we can statically prove
4902  // that tile sizes do not evenly divide the dimension; thus, only check
4903  // that tiles and the tiled inner dimensions are constant.
4904  if (!areTilesAndTiledDimsAllConstant(*this))
4905  return Speculation::NotSpeculatable;
4906 
4907  return Speculation::Speculatable;
4908 }
4909 
4910 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4911 // dimensions for pack and unpack.
4912 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4913  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4914  return false;
4915  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4916  return true;
4917  // Outer dims permutation is optional.
4918  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
4919  // identity permutation.
4920  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4921  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4922 }
4923 
4924 // Return true if pack and unpack have the same tiles.
4925 // Same SSA values or same integer constants.
4926 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4927  auto packTiles = packOp.getMixedTiles();
4928  auto unPackTiles = unPackOp.getMixedTiles();
4929  if (packTiles.size() != unPackTiles.size())
4930  return false;
4931  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4932  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4933  return false;
4934  }
4935  return true;
4936 }
4937 
4938 /// Returns true if the pack op does not need a padding value.
4939 static bool paddingIsNotNeeded(PackOp op) {
4940  auto srcType = op.getSourceType();
4941  if (llvm::any_of(op.getInnerDimsPos(),
4942  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4943  return false;
4944  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4945  return false;
4946  return !PackOp::requirePaddingValue(
4947  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4948  op.getOuterDimsPerm(), op.getMixedTiles());
4949 }
4950 
4951 /// Returns true if the `srcShape` or `destShape` is different from the one in
4952 /// `packOp` and populates each with the inferred static shape.
4953 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4954  SmallVectorImpl<int64_t> &destShape) {
4955  bool changeNeeded = false;
4956  srcShape.assign(packOp.getSourceType().getShape().begin(),
4957  packOp.getSourceType().getShape().end());
4958  destShape.assign(packOp.getDestType().getShape().begin(),
4959  packOp.getDestType().getShape().end());
4960  llvm::SmallSetVector<int64_t, 4> innerDims;
4961  innerDims.insert_range(packOp.getInnerDimsPos());
4962  SmallVector<int64_t> inverseOuterDimsPerm;
4963  if (!packOp.getOuterDimsPerm().empty())
4964  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4965  int srcRank = packOp.getSourceRank();
4966  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4967  if (innerDims.contains(i))
4968  continue;
4969  int64_t srcPos = i;
4970  int64_t destPos = i;
4971  if (!inverseOuterDimsPerm.empty())
4972  destPos = inverseOuterDimsPerm[srcPos];
4973  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4974  ShapedType::isDynamic(destShape[destPos])) {
4975  continue;
4976  }
4977  int64_t size = srcShape[srcPos];
4978  if (ShapedType::isDynamic(size))
4979  size = destShape[destPos];
4980  srcShape[srcPos] = size;
4981  destShape[destPos] = size;
4982  changeNeeded = true;
4983  }
4984  return changeNeeded;
4985 }
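// For example, packing tensor<?x256xf32> into tensor<4x8x32xf32> with
// inner_dims_pos = [1]: dim 0 is untiled, so its static size 4 can be
// propagated from the dest back to the source, giving srcShape = [4, 256].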
4986 
4987 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4988  // Fold pack(unpack(x)) to x.
4989  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4990  if (unPackOp.getSourceType() != packOp.getDestType())
4991  return failure();
4992  if (packOp.getPaddingValue() ||
4993  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4994  !haveSameTiles(packOp, unPackOp))
4995  return failure();
4996  rewriter.replaceOp(packOp, unPackOp.getSource());
4997  return success();
4998  }
4999 
5000  // Fold optional PaddingValue operand away if padding is not needed.
5001  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5002  rewriter.startOpModification(packOp);
5003  packOp.getPaddingValueMutable().clear();
5004  rewriter.finalizeOpModification(packOp);
5005  return success();
5006  }
5007 
5008  // Insert tensor.cast ops if static shape inference is available.
5009  SmallVector<int64_t> srcShape, destShape;
5010  if (inferStaticShape(packOp, srcShape, destShape)) {
5011  Location loc = packOp.getLoc();
5012  Value source = packOp.getSource();
5013  if (srcShape != packOp.getSourceType().getShape()) {
5014  auto newSrcType = packOp.getSourceType().clone(srcShape);
5015  source =
5016  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
5017  }
5018  Value dest = packOp.getDest();
5019  RankedTensorType originalResultType = packOp.getDestType();
5020  bool needUpdateDestType = (destShape != originalResultType.getShape());
5021  if (needUpdateDestType) {
5022  auto newDestType = packOp.getDestType().clone(destShape);
5023  dest =
5024  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
5025  }
5026  rewriter.modifyOpInPlace(packOp, [&] {
5027  packOp.getSourceMutable().assign(source);
5028  packOp.getDestMutable().assign(dest);
5029  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5030  });
5031  // Insert a cast if needed
5032  if (needUpdateDestType) {
5033  rewriter.setInsertionPointAfter(packOp);
5034  auto castOp =
5035  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5036  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5037  }
5038  return success();
5039  }
5040 
5041  return failure();
5042 }
5043 
5044 template <typename PackOrUnpackOp>
5045 static bool isLikePadUnPad(PackOrUnpackOp packOp,
5046  RankedTensorType packedTensorType) {
5047  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5048  std::is_same<PackOrUnpackOp, UnPackOp>::value,
5049  "Function meant for pack/unpack");
5050  // This is a pad if packing only adds ones and we don't transpose dimensions.
5051 
5052  // Check that we are not transposing any dimensions.
5053  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5054  int64_t numPackedDims = innerDimsPos.size();
5055  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5056  if (orderedDims != innerDimsPos) {
5057  // Dimensions don't happen in order.
5058  return false;
5059  }
5060 
5061  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5062  int64_t packedRank = packedTensorType.getRank();
5063  // At this point we know that we are taking numPackedDims outer
5064  // dimensions and pushing them all the way as the inner most dimensions.
5065  // What's left on the outer most dimensions is, in this order:
5066  // - the factor of the packed dimensions, then
5067  // - the untouched dimensions
5068  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5069  // if all the dimensions that bubble outerward are ones.
5070  // Therefore check that all the dimensions but the numPackedDims inner most
5071  // ones are ones.
5072  return llvm::all_of(
5073  llvm::seq<int64_t>(0, packedRank - numPackedDims),
5074  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5075 }
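// For example, a pack of tensor<64xf32> into tensor<1x64xf32> with
// inner_dims_pos = [0] and inner_tiles = [64] is like-pad: the only outer
// dim is 1 and no dimension is transposed.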
5076 
5077 bool PackOp::isLikePad() {
5078  auto packedTensorType =
5079  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5080  return isLikePadUnPad(*this, packedTensorType);
5081 }
5082 
5083 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5084  std::optional<Attribute> paddingValue;
5085  if (auto pad = adaptor.getPaddingValue())
5086  paddingValue = pad;
5087  if (OpFoldResult reshapedSource = reshapeConstantSource(
5088  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5089  getDestType(), paddingValue))
5090  return reshapedSource;
5091  return {};
5092 }
5093 
5094 /// Folds a tensor.cast op into a consuming PackOp op if the
5095 /// `tensor.cast` has source that is more static than the consuming op.
5096 ///
5097 /// Example:
5098 /// ```mlir
5099 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5100 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5101 /// ```
5102 ///
5103 /// folds into:
5104 ///
5105 /// ```mlir
5106 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5107 /// ```
5108 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5109  using OpRewritePattern<PackOp>::OpRewritePattern;
5110 
5111  LogicalResult matchAndRewrite(PackOp op,
5112  PatternRewriter &rewriter) const override {
5113  if (!tensor::hasFoldableTensorCastOperands(op))
5114  return failure();
5115 
5116  SmallVector<Type> newResultTypes(op->getResultTypes());
5117  SmallVector<Value> newOperands =
5118  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5119 
5120  // Get the updated mixed-tile-sizes attribute.
5121  SmallVector<OpFoldResult> newMixedTileSizes =
5122  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5123 
5124  // Clone op.
5125  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5126  // this point. However, in practice, we use them for things that we'd like
5127  // to preserve. Implement a better abstraction.
5128  PackOp newOp = rewriter.create<PackOp>(
5129  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5130  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5131  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5132 
5133  // Replace op.
5134  Value oldResult = op.getResult();
5135  Value newResult = newOp.getResult();
5136  Value replacement = (newResult.getType() != oldResult.getType())
5137  ? rewriter.create<tensor::CastOp>(
5138  op->getLoc(), oldResult.getType(), newResult)
5139  : newResult;
5140 
5141  rewriter.replaceOp(op, {replacement});
5142 
5143  return success();
5144  }
5145 };
5146 
5147 //===----------------------------------------------------------------------===//
5148 // UnPackOp
5149 //===----------------------------------------------------------------------===//
5150 
5151 void UnPackOp::getAsmResultNames(
5152  function_ref<void(Value, StringRef)> setNameFn) {
5153  setNameFn(getResult(), "unpack");
5154 }
5155 
5156 LogicalResult
5157 UnPackOp::reifyResultShapes(OpBuilder &builder,
5158  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5159  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5160 }
5161 
5162 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5163  return getDimAndTileMappingImpl(*this);
5164 }
5165 
5166 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5167  return getMixedTilesImpl(*this);
5168 }
5169 
5170 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5171  return getStaticTilesImpl(*this);
5172 }
5173 
5174 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5175  ShapedType destType = getDestType();
5176  int64_t destRank = destType.getRank();
5177  return getSourceType().getShape().take_front(destRank);
5178 }
5179 
5180 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5181  auto innerDimsPos = getInnerDimsPos();
5182  auto packedShape = getSourceType().getShape();
5183  SmallVector<int64_t> res;
5184 
5185  for (auto index : innerDimsPos)
5186  res.push_back(packedShape[index]);
5187 
5188  return res;
5189 }
5190 
5191 LogicalResult UnPackOp::verify() {
5192  return commonVerifierPackAndUnPackOp(*this);
5193 }
5194 
5195 Speculation::Speculatability UnPackOp::getSpeculatability() {
5196  // See PackOp::getSpeculatability.
5197  if (!areTilesAndTiledDimsAllConstant(*this))
5198  return Speculation::NotSpeculatable;
5199 
5200  return Speculation::Speculatable;
5201 }
5202 
5203 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5204  Value dest, ArrayRef<int64_t> innerDimsPos,
5205  ArrayRef<OpFoldResult> innerTiles,
5206  ArrayRef<int64_t> outerDimsPerm) {
5207  assert(innerDimsPos.size() == innerTiles.size() &&
5208  "number of tile sizes specified must match the specified number of "
5209  "original dimensions to be tiled");
5210  SmallVector<int64_t> staticTileSizes;
5211  SmallVector<Value> dynamicTileSizes;
5212  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5213  build(builder, state, dest.getType(), source, dest,
5214  outerDimsPerm.empty() ? nullptr
5215  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5216  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5217  builder.getDenseI64ArrayAttr(staticTileSizes));
5218 }
5219 
5220 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5221  Value source,
5222  ArrayRef<OpFoldResult> innerTileSizes,
5223  ArrayRef<int64_t> innerDimsPos,
5224  ArrayRef<int64_t> outerDimsPerm) {
5225  AffineExpr sym0, sym1;
5226  bindSymbols(b.getContext(), sym0, sym1);
5227  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5228  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5229  };
5230 
5231  SmallVector<OpFoldResult> mixedSizes;
5232  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5233  for (auto i :
5234  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5235  if (srcType.isDynamicDim(i))
5236  mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
5237  else
5238  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5239  }
5240  if (!outerDimsPerm.empty()) {
5241  applyPermutationToVector<OpFoldResult>(
5242  mixedSizes, invertPermutationVector(outerDimsPerm));
5243  }
5244 
5245  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5246  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5247 
5248  auto elemType = srcType.getElementType();
5249  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5250 }
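// For example, unpacking a tensor<16x8x8x32xf32> source with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32] creates an empty dest of
// type tensor<128x256xf32> (16 * 8 = 128, 8 * 32 = 256).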
5251 
5252 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5253  Value transposedSource,
5254  ArrayRef<int64_t> innerPermutation,
5255  ArrayRef<int64_t> outerPermutation) {
5256  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5257  *this, innerPermutation, outerPermutation);
5258  return b.create<UnPackOp>(loc, transposedSource, getDest(),
5259  metadata.innerDimsPos, metadata.innerTiles,
5260  metadata.outerDimsPerm);
5261 }
5262 
5263 /// Returns true if the `srcShape` or `destShape` is different from the one in
5264 /// `op` and populates each with the inferred static shape.
5265 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5266  SmallVectorImpl<int64_t> &destShape) {
5267  bool changeNeeded = false;
5268  srcShape.assign(op.getSourceType().getShape().begin(),
5269  op.getSourceType().getShape().end());
5270  destShape.assign(op.getDestType().getShape().begin(),
5271  op.getDestType().getShape().end());
5272  llvm::SmallSetVector<int64_t, 4> innerDims;
5273  innerDims.insert_range(op.getInnerDimsPos());
5274  SmallVector<int64_t> inverseOuterDimsPerm;
5275  if (!op.getOuterDimsPerm().empty())
5276  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5277  int destRank = op.getDestRank();
5278  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5279  if (innerDims.contains(i))
5280  continue;
5281  int64_t srcPos = i;
5282  int64_t destPos = i;
5283  if (!inverseOuterDimsPerm.empty())
5284  srcPos = inverseOuterDimsPerm[destPos];
5285  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5286  ShapedType::isDynamic(destShape[destPos])) {
5287  continue;
5288  }
5289  int64_t size = srcShape[srcPos];
5290  if (ShapedType::isDynamic(size))
5291  size = destShape[destPos];
5292  srcShape[srcPos] = size;
5293  destShape[destPos] = size;
5294  changeNeeded = true;
5295  }
5296  return changeNeeded;
5297 }
5298 
5299 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5300  PatternRewriter &rewriter) {
5301  /// unpack(pack(x)) -> x
5302  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5303  if (packOp.getSourceType() != unPackOp.getDestType())
5304  return failure();
5305  if (packOp.getPaddingValue() ||
5306  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5307  !haveSameTiles(packOp, unPackOp))
5308  return failure();
5309  rewriter.replaceOp(unPackOp, packOp.getSource());
5310  return success();
5311  }
5312  /// unpack(destinationStyleOp(x)) -> unpack(x)
5313  if (auto dstStyleOp =
5314  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5315  auto destValue = cast<OpResult>(unPackOp.getDest());
5316  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5317  rewriter.modifyOpInPlace(unPackOp,
5318  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5319  return success();
5320  }
5321  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5322  if (unPackOp->hasOneUse()) {
5323  auto extractSliceUser =
5324  dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5325  if (extractSliceUser &&
5326  areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
5327  areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
5328  extractSliceUser.getSourceType().getRank() ==
5329  extractSliceUser.getResultType().getRank()) {
5330  OpBuilder::InsertionGuard g(rewriter);
5331  rewriter.setInsertionPoint(unPackOp);
5332  auto newDest = rewriter.create<tensor::ExtractSliceOp>(
5333  unPackOp->getLoc(), unPackOp.getDest(),
5334  extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5335  extractSliceUser.getMixedStrides());
5336  rewriter.modifyOpInPlace(unPackOp, [&]() {
5337  unPackOp.setDpsInitOperand(0, newDest);
5338  unPackOp.getResult().setType(newDest.getType());
5339  });
5340  rewriter.replaceOp(extractSliceUser, unPackOp);
5341  return success();
5342  }
5343  }
5344 
5345  // Insert tensor.cast ops if static shape inference is available.
5346  SmallVector<int64_t> srcShape, destShape;
5347  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5348  Location loc = unPackOp.getLoc();
5349  Value source = unPackOp.getSource();
5350  if (srcShape != unPackOp.getSourceType().getShape()) {
5351  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5352  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5353  unPackOp.getSource());
5354  }
5355  Value dest = unPackOp.getDest();
5356  if (destShape != unPackOp.getDestType().getShape()) {
5357  auto newDestType = unPackOp.getDestType().clone(destShape);
5358  dest =
5359  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5360  }
5361  Value newOp = rewriter.create<UnPackOp>(
5362  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5363  unPackOp.getOuterDimsPerm());
5364  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5365  unPackOp, unPackOp.getResult().getType(), newOp);
5366  return success();
5367  }
5368 
5369  return failure();
5370 }
5371 
5372 bool UnPackOp::isLikeUnPad() {
5373  RankedTensorType packedTensorType = getSourceType();
5374  return isLikePadUnPad(*this, packedTensorType);
5375 }
5376 
5377 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5378  if (OpFoldResult reshapedSource = reshapeConstantSource(
5379  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5380  getResult().getType()))
5381  return reshapedSource;
5382  return {};
5383 }
5384 
5385 /// Folds a tensor.cast op into a consuming UnPackOp op if the
5386 /// `tensor.cast` has source that is more static than the consuming op.
5387 ///
5388 /// Example:
5389 /// ```mlir
5390 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5391 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5392 /// ```
5393 ///
5394 /// folds into:
5395 ///
5396 /// ```mlir
5397 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5398 /// ```
5399 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5400  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5401 
5402  LogicalResult matchAndRewrite(UnPackOp op,
5403  PatternRewriter &rewriter) const override {
5404  if (!tensor::hasFoldableTensorCastOperands(op))
5405  return failure();
5406 
5407  SmallVector<Type> newResultTypes(op->getResultTypes());
5408  SmallVector<Value> newOperands =
5409  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5410  Value sourceTensor = newOperands[0];
5411 
5412  // Get the updated mixed-tile-sizes attribute.
5413  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5414  rewriter, sourceTensor.getType(), op.getMixedTiles());
5415 
5416  // Clone op.
5417  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5418  // this point. However, in practice, we use them for things that we'd like
5419  // to preserve. Implement a better abstraction.
5420  UnPackOp newOp = rewriter.create<UnPackOp>(
5421  op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5422  newMixedTileSizes, op.getOuterDimsPerm());
5423  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5424 
5425  // Replace op.
5426  Value oldResult = op.getResult();
5427  Value newResult = newOp.getResult();
5428  Value replacement = (newResult.getType() != oldResult.getType())
5429  ? rewriter.create<tensor::CastOp>(
5430  op->getLoc(), oldResult.getType(), newResult)
5431  : newResult;
5432 
5433  rewriter.replaceOp(op, {replacement});
5434 
5435  return success();
5436  }
5437 };
5438 
5439 //===----------------------------------------------------------------------===//
5440 // BatchReduceMatmulOp
5441 //===----------------------------------------------------------------------===//
5442 SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
5443  return SmallVector<utils::IteratorType>{
5444  utils::IteratorType::reduction, utils::IteratorType::parallel,
5445  utils::IteratorType::parallel, utils::IteratorType::reduction};
5446 }
5447 
5448 SmallVector<AffineMap>
5449 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5450  AffineExpr d0, d1, d2, d3;
5451  SmallVector<AffineMap> indexingMaps;
5452  bindDims(context, d0, d1, d2, d3);
5453  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
5454  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
5455  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
5456  return indexingMaps;
5457 }
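// Written out, the three default maps above are:
//   (d0, d1, d2, d3) -> (d0, d1, d3)  // LHS: batch, M, K
//   (d0, d1, d2, d3) -> (d0, d3, d2)  // RHS: batch, K, N
//   (d0, d1, d2, d3) -> (d1, d2)      // output: M, N (batch and K reduced)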
5458 
5459 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5460 
5461 std::string BatchReduceMatmulOp::getLibraryCallName() {
5462  return generateLibraryCallName(getOperation());
5463 }
5464 
5465 /// Check if the op has broadcast and/or transpose semantics. Returns true if
5466 /// the user-defined indexing maps differ from the default maps.
5467 bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5468  SmallVector<AffineMap, 3> defaultMaps =
5469  getDefaultIndexingMaps(this->getContext());
5470  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
5471  return defaultMaps != explicitMaps;
5472 }
5473 
5474 /// Returns true if the given `bcastMap` is a valid broadcast map. A valid
5475 /// broadcast map must include the K dimension.
5476 /// TODO: Strict inclusion of K dimension in the broadcast map is not
5477 /// necessary for both input matrices simultaneously. We can relax this
5478 /// condition to have K dimension for one input matrix map and infer the K
5479 /// dimension for other input matrix map from the one already having K
5480 /// dimension.
5481 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
5482  bool isLHS) {
5483  assert(bcastMap.getNumResults() < 3 &&
5484  "Expected less than 3 result dim expr.");
5485  bool isValid = false;
5486  enum Indices { batchPos, mPos, nPos, kPos };
5487  if (bcastMap.getNumResults() == 1) {
5488  AffineExpr expr = bcastMap.getResult(0);
5489  isValid = expr.isFunctionOfDim(kPos);
5490  } else if (bcastMap.getNumResults() == 2) {
5491  AffineExpr expr0 = bcastMap.getResult(0);
5492  AffineExpr expr1 = bcastMap.getResult(1);
5493  isValid =
5494  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
5495  expr0.isFunctionOfDim(mPos)) &&
5496  expr1.isFunctionOfDim(kPos))
5497  : ((expr0.isFunctionOfDim(batchPos) &&
5498  expr1.isFunctionOfDim(kPos)) ||
5499  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
5500  }
5501  return isValid;
5502 }
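// For example, (d0, d1, d2, d3) -> (d3) is a valid broadcast map for either
// operand because it retains the K dimension (d3), whereas
// (d0, d1, d2, d3) -> (d1) is rejected.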
5503 
5504 void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
5505  ArrayRef<NamedAttribute> attrs) {
5506  assert(block.getNumArguments() == 3 &&
5507  "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
5508  RegionBuilderHelper helper(b, block);
5509  SmallVector<Value> yields;
5510 
5511  auto toType = block.getArgument(2).getType();
5512  Value castValA =
5513  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
5514  Value castValB =
5515  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
5516  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
5517  Value addVal =
5518  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
5519  yields.push_back(addVal);
5520  helper.yieldOutputs(yields);
5521 }
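// Assuming f32 operands (where cast_signed is a no-op), the region built
// above amounts to:
//
// ```mlir
// ^bb0(%a: f32, %b: f32, %acc: f32):
//   %prod = arith.mulf %a, %b : f32
//   %sum = arith.addf %acc, %prod : f32
//   linalg.yield %sum : f32
// ```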
5522 
5523 ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
5524  OperationState &result) {
5525  SmallVector<Attribute, 3> indexingMapsAttr;
5526  Attribute mapAttr;
5527  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
5528  if (parser.parseEqual())
5529  return failure();
5530  if (parser.parseLSquare())
5531  return failure();
5532 
5533  do {
5534  if (parser.parseAttribute(mapAttr))
5535  return failure();
5536  if (!isa<AffineMapAttr>(mapAttr)) {
5537  return parser.emitError(parser.getCurrentLocation(),
5538  "expected affine map attribute");
5539  }
5540  indexingMapsAttr.push_back(mapAttr);
5541 
5542  if (parser.parseOptionalComma())
5543  break;
5544  } while (true);
5545 
5546  if (parser.parseRSquare())
5547  return failure();
5548  }
5549  // Initialize indexingMaps, if not supplied explicitly.
5550  if (indexingMapsAttr.empty()) {
5551  indexingMapsAttr = llvm::map_to_vector(
5552  BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
5553  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5554  }
5555  result.addAttribute("indexing_maps",
5556  parser.getBuilder().getArrayAttr(indexingMapsAttr));
5557  return ::parseNamedStructuredOp(parser, result,
5558  BatchReduceMatmulOp::getNumRegionArgs(),
5559  BatchReduceMatmulOp::getRegionBuilder());
5560 }
5561 
5562 void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
5563  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
5564  BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
5565  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5566 
5567  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
5568  p << " indexing_maps = [";
5569  llvm::interleaveComma(getIndexingMaps(), p,
5570  [&](Attribute attr) { p.printAttribute(attr); });
5571  p << "]";
5572  }
5573 
5574  SmallVector<StringRef, 3> elidedAttrs = {
5575  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
5576  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
5577  elidedAttrs);
5578 }
5579 
5580 /// Verify the user defined indexing maps.
5581 LogicalResult BatchReduceMatmulOp::verify() {
5582  // Verification of pure batch_reduce_matmul is handled by
5583  // verifyStructuredOpInterface().
5584  if (!hasUserDefinedMaps())
5585  return success();
5586 
5587  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5588  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
5589  return failure();
5590  }
5591  return success();
5592 }
5593 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
5594  SmallVectorImpl<OpFoldResult> &) {
5595  return memref::foldMemRefCast(*this);
5596 }
5597 void BatchReduceMatmulOp::getEffects(
5598  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5599  &effects) {
5600  if (hasPureTensorSemantics())
5601  return;
5602  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5603 }
5604 
5605 Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
5606  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5607 }
5608 
5609 } // namespace linalg
5610 } // namespace mlir
5611 
5612 //===----------------------------------------------------------------------===//
5613 // LinalgDialect
5614 //===----------------------------------------------------------------------===//
5615 
5616 void LinalgDialect::getCanonicalizationPatterns(
5617  RewritePatternSet &results) const {
5618  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5619  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5620 }
5621 
5622 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5623  Attribute value, Type type,
5624  Location loc) {
5625  return arith::ConstantOp::materialize(builder, value, type, loc);
5626 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3491
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1880
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:329
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2832
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2903
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
Definition: LinalgOps.cpp:3469
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:4585
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:128
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2370
SmallVector< OpFoldResult > innerTiles
Definition: LinalgOps.cpp:4584
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
Definition: LinalgOps.cpp:3484
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:317
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1713
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1761
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2877
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:190
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
Definition: LinalgOps.cpp:3555
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1513
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:162
TernaryFn ternaryFn
Definition: LinalgOps.cpp:4154
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1383
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:206
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
Definition: LinalgOps.cpp:3517
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2854
ElementwiseArityGroup arityGroup
Definition: LinalgOps.cpp:4148
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1264
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1231
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2263
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4583
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:355
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:1006
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:348
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:61
UnaryFn unaryFn
Definition: LinalgOps.cpp:4152
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1437
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:223
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1542
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:385
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3596
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:392
BinaryFn binaryFn
Definition: LinalgOps.cpp:4153
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:243
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expressions.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:968
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e. a projection) of a symbol-less permutation map.
Definition: AffineMap.cpp:615
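Taken together, getMultiDimIdentityMap, dropResults, and isProjectedPermutation cover a common idiom; a minimal sketch, assuming an MLIRContext *ctx is in scope:

// Build (d0, d1, d2, d3) -> (d0, d1, d2, d3), then drop result 1 to get
// the projected permutation (d0, d1, d2, d3) -> (d0, d2, d3).
AffineMap id = AffineMap::getMultiDimIdentityMap(4, ctx);
AffineMap proj = id.dropResults({1});
assert(proj.isProjectedPermutation());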
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
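A usage sketch for permutation maps (ctx is again an assumed MLIRContext *; inversePermutation is documented further below):

// (d0, d1, d2) -> (d2, d0, d1); composing with its inverse yields identity.
AffineMap perm = AffineMap::getPermutationMap({2, 0, 1}, ctx);
AffineMap inv = inversePermutation(perm);
assert(inv.compose(perm).isIdentity());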
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
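The Builder attribute helpers above compose naturally; a small sketch (ctx assumed, values illustrative):

Builder b(ctx);
IntegerAttr rank = b.getIndexAttr(2);
DenseI64ArrayAttr dims = b.getDenseI64ArrayAttr({0, 2});
ArrayAttr maps =
    b.getAffineMapArrayAttr({b.getMultiDimIdentityMap(/*rank=*/2)});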
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:730
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with a new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:395
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:409
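A typical pairing of the RAII InsertionGuard with the insertion-point setters; a sketch assuming b, region, and loc exist:

{
  OpBuilder::InsertionGuard guard(b);        // restores the old point on exit
  b.setInsertionPointToStart(&region.front());
  b.create<arith::ConstantIndexOp>(loc, 0);  // emitted at the block start
} // b's original insertion point is active again here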
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:666
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:578
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
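For in-place updates, modifyOpInPlace wraps the startOpModification/finalizeOpModification pair shown above; a sketch (the attribute name and value are illustrative only):

rewriter.modifyOpInPlace(op, [&] {
  op->setAttr("library_call", rewriter.getStringAttr("external_fn"));
});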
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1175
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1225
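A folded-apply sketch, assuming b, loc, and an index-typed Value n are available; the constant is illustrative:

// Compute n + 4, folding to an IntegerAttr when n is a known constant.
AffineExpr d0, d1;
bindDims(b.getContext(), d0, d1);
OpFoldResult sum = affine::makeComposedFoldedAffineApply(
    b, loc, AffineMap::get(2, 0, {d0 + d1}, b.getContext()),
    {n, b.getIndexAttr(4)});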
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2651
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4408
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: LinalgOps.cpp:4953
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:5045
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic,...
Definition: LinalgOps.cpp:4746
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4464
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if each dimension of sourceShape is smaller than or equal to the corresponding dimension of limitShape.
Definition: LinalgOps.cpp:4479
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4451
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4761
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2362
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
Definition: LinalgOps.cpp:4163
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4591
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:108
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2403
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:4939
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2342
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2353
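makeAffineDimExprs and concat are typically used together when assembling indexing maps; a sketch with ctx assumed:

unsigned idx = 0;
auto batch = makeAffineDimExprs(1, idx, ctx);    // [d0], idx is now 1
auto spatial = makeAffineDimExprs(2, idx, ctx);  // [d1, d2], idx is now 3
auto all = concat(batch, spatial);               // [d0, d1, d2]
AffineMap map = AffineMap::get(idx, 0, all, ctx);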
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4420
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4376
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4926
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
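The two dim helpers differ only in their result kind; a sketch assuming b, loc, and a shaped Value val:

Value d0 = createOrFoldDimOp(b, loc, val, /*dim=*/0);         // always a Value
OpFoldResult d0f = createFoldedDimOp(b, loc, val, /*dim=*/0); // Attr if static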
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4912
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:4881
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4435
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4494
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3707
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:363
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:356
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:372
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:73
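A sketch of querying all sizes of a tensor at once (b, loc, and a ranked tensor Value t assumed):

SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, t);
// Static dims come back as IntegerAttrs, dynamic ones as tensor.dim Values.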
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
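A bindDims sketch (ctx assumed):

AffineExpr d0, d1;
bindDims(ctx, d0, d1);  // d0, d1 now refer to dims 0 and 1
AffineMap m = AffineMap::get(2, 0, {d0.floorDiv(4), d1 % 4}, ctx);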
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:788
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
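A sketch of splitting mixed sizes into their static and dynamic halves, assuming a SmallVector<OpFoldResult> sizes such as the one from the getMixedSizes sketch above:

SmallVector<Value> dynSizes;
SmallVector<int64_t> staticSizes;  // Values are recorded as ShapedType::kDynamic
dispatchIndexOpFoldResults(sizes, dynSizes, staticSizes);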
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1297
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to avoid using the classes in the detail namespace.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
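The permutation-vector helpers compose as follows (a sketch):

SmallVector<int64_t> perm = {2, 0, 1};
assert(isPermutationVector(perm));
SmallVector<int64_t> inv = invertPermutationVector(perm);  // {1, 2, 0}
SmallVector<int64_t> shape = {10, 20, 30};
applyPermutationToVector(shape, perm);                     // shape == {30, 10, 20}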
Fold transpose with transpose.
Definition: LinalgOps.cpp:2023
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2026
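In the same spirit as these canonicalizations, a minimal rewrite-pattern sketch; this is a hypothetical pattern for illustration, not the one defined here:

struct DropIdentityTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (!isIdentityPermutation(op.getPermutation()))
      return rewriter.notifyMatchFailure(op, "non-identity permutation");
    rewriter.replaceOp(op, op.getInput());  // forward the untransposed input
    return success();
  }
};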
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:2048
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2051
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
Definition: LinalgOps.cpp:5108
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5111
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
Definition: LinalgOps.cpp:5399
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5402