MLIR 23.0.0git
SPIRVTosaOps.cpp
Go to the documentation of this file.
1//===- SPIRVTosaOps.cpp - MLIR SPIR-V Tosa operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Tosa operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14#include "llvm/ADT/STLExtras.h"
15#include "llvm/Support/InterleavedRange.h"
16#include <algorithm>
17
18namespace mlir::spirv {
19
20//===----------------------------------------------------------------------===//
21// SPIRV Tosa Custom formatters
22//===----------------------------------------------------------------------===//
23
27 auto f = [&]() {
28 int32_t value;
29 ParseResult r = parser.parseInteger(value);
30 elements.push_back(value);
31 return r;
32 };
33 if (parser.parseCommaSeparatedList(
35 "parsing values in integer list attribute")) {
36 return failure();
37 }
38
39 auto i32Type = IntegerType::get(parser.getContext(), 32);
40 auto type = TensorArmType::get(
41 ArrayRef{static_cast<int64_t>(elements.size())}, i32Type);
42 attr = DenseIntElementsAttr::get(type, elements);
43 return success();
44}
45
48 printer << llvm::interleaved_array(
49 llvm::map_range(attr.getValues<APInt>(),
50 [](const APInt &a) { return a.getSExtValue(); }));
51}
52
53//===----------------------------------------------------------------------===//
54// SPIRV Tosa Custom verifiers
55//===----------------------------------------------------------------------===//
56
57LogicalResult TosaSelectOp::verify() {
58 TensorArmType condType = getConditionType();
59 TensorArmType trueValType = getTrueValueType();
60 TensorArmType falseValType = getFalseValueType();
61 TensorArmType resultType = getResultType();
62
63 if (llvm::any_of(ArrayRef<TensorArmType>{condType, trueValType, falseValType,
64 resultType},
65 [](TensorArmType type) { return !type.hasRank(); }))
66 return success();
67
68 ArrayRef<int64_t> condShape = condType.getShape();
69 ArrayRef<int64_t> trueValShape = trueValType.getShape();
70 ArrayRef<int64_t> falseValShape = falseValType.getShape();
71 ArrayRef<int64_t> resultShape = resultType.getShape();
72
73 if (!llvm::all_equal({condShape.size(), trueValShape.size(),
74 falseValShape.size(), resultShape.size()})) {
75 // The AllRanksMatch predicate enforces that all ranks are equal.
76 // This is just an extra safe guard for the code coming after that
77 // assumes that all ranks are equal.
78 return failure();
79 }
80
81 for (auto dims :
82 llvm::zip_equal(condShape, trueValShape, falseValShape, resultShape)) {
83 auto [condDim, trueValDim, falseValDim, resultDim] = dims;
84
85 if (llvm::any_of(
86 ArrayRef<int64_t>{condDim, trueValDim, falseValDim, resultDim},
87 [](int64_t dim) { return ShapedType::isDynamic(dim); })) {
88 continue;
89 }
90
91 auto isPairBroadcastable = [](int64_t lhs, int64_t rhs) {
92 return lhs == rhs || lhs == 1 || rhs == 1;
93 };
94
95 if (!isPairBroadcastable(condDim, trueValDim) ||
96 !isPairBroadcastable(condDim, falseValDim) ||
97 !isPairBroadcastable(trueValDim, falseValDim)) {
98 return emitOpError(
99 "failed to verify that the shape of inputs: condition, "
100 "true_value, and false_value are compatible for "
101 "broadcasting");
102 }
103
104 int64_t bradcastedInputDim =
105 std::max(condDim, std::max(trueValDim, falseValDim));
106 if (bradcastedInputDim != resultDim) {
107 return emitOpError(
108 "failed to verify that the broadcast shape of inputs: condition, "
109 "true_value, and false_value is equal to "
110 "the output shape");
111 }
112 }
113 return success();
114}
115
116} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
lhs
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
void printSPIRV_I32_1DArmTensor(OpAsmPrinter &printer, Operation *, DenseIntElementsAttr attr)
ParseResult parseSPIRV_I32_1DArmTensor(OpAsmParser &parser, DenseIntElementsAttr &attr)